Mirror of https://github.com/langbot-app/LangBot.git (synced 2025-11-25 19:37:36 +08:00)

Compare commits
142 Commits
| SHA1 |
|---|
| 79804b6ecd |
| 76434b2f4e |
| ec8bd4922e |
| 4ffa773fac |
| ea8b7bc8aa |
| 39ce5646f6 |
| 5092a82739 |
| 3bba0b6d9a |
| 7a19dd503d |
| 9e6a01fefd |
| 933471b4d9 |
| f81808d239 |
| 96832b6f7d |
| e2eb0a84b0 |
| c8eb2e3376 |
| 21fe5822f9 |
| d49cc9a7a3 |
| 910d0bfae1 |
| d6761949ca |
| 6afac1f593 |
| 4d1a270d22 |
| a7888f5536 |
| b9049e91cf |
| 7db56c8e77 |
| 50563cb957 |
| 18ae2299a7 |
| 7463e0aab9 |
| c92d47bb95 |
| 0b1af7df91 |
| a9104eb2da |
| abbd15d5cc |
| aadfa14d59 |
| 4cd10bbe25 |
| 1d4a6b71ab |
| a7f830dd73 |
| bae86ac05c |
| a3706bfe21 |
| 91e23b8c11 |
| 37ef1c9fab |
| 6bc6f77af1 |
| 2c478ccc25 |
| 404e5492a3 |
| d5b5d667a5 |
| 8807f02f36 |
| 269e561497 |
| 527ad81d38 |
| 972d3c18af |
| 3cbfc078fc |
| fde6822b5c |
| 930321bcf1 |
| c45931363a |
| 9c6491e5ee |
| 9bc248f5bc |
| becac2fde5 |
| 1e1a103882 |
| e5cffb7c9b |
| e2becf7777 |
| a6b875a242 |
| b5e67f3df8 |
| 2093fb16a7 |
| fc9a9d2386 |
| 5e69f78f7e |
| 6919bece77 |
| 8b003739f1 |
| 2e9229a6ad |
| 5a3e7fe8ee |
| 7b3d7e7bd6 |
| fdd7c1864d |
| cac5a5adff |
| 63307633c2 |
| 387dfa39ff |
| 1f797f899c |
| 092bb0a1e2 |
| 2c3399e237 |
| 835275b47f |
| 7b060ce3f9 |
| 1fb69311b0 |
| 995d1f61d2 |
| 80258e9182 |
| bd6a32e08e |
| 5f138de75b |
| d0b0f2209a |
| 0752698c1d |
| 9855c6b8f5 |
| 52a7c25540 |
| fa823de6b0 |
| f53070d8b6 |
| 7677672691 |
| dead8fa168 |
| c6347bea45 |
| 32bd194bfc |
| cca48a394d |
| a723c8ce37 |
| 327b2509f6 |
| 1dae7bd655 |
| 550a131685 |
| 0cfb8bb29f |
| 9c32420a95 |
| 867093cc88 |
| 82763f8ec5 |
| 97449065df |
| 9489783846 |
| f91c9015bc |
| 302d86056d |
| 98bebfddaa |
| dab20e3187 |
| 09e72f7c5f |
| 2028d85f84 |
| ed3c0d9014 |
| be06150990 |
| afb3fb4a31 |
| d66577e6c3 |
| 6a4ea5446a |
| 74e84c744a |
| 5ad2446cf3 |
| 63303bb5c0 |
| 13393b6624 |
| b9fa11c0c3 |
| 8c6ce1f030 |
| 1d963d0f0c |
| 0ee383be27 |
| 53d09129b4 |
| a398c6f311 |
| 4347ddd42a |
| 22cb8a6a06 |
| 7f554fd862 |
| a82bfa8a56 |
| 95784debbf |
| 2471c5bf0f |
| 2fe6d731b8 |
| ce881372ee |
| 171ea7c375 |
| 1e9a6f813f |
| 39a7f3b2b9 |
| 8d375a02db |
| cac8a0a414 |
| c89623967e |
| 92aa9c1711 |
| 71f2a58acb |
| 1f07a8a9e3 |
| cacd21bde7 |
| a060ec66c3 |
.github/ISSUE_TEMPLATE/bug-report.yml (vendored) | 8

@@ -16,11 +16,13 @@ body:
required: true
- type: dropdown
attributes:
label: 登录框架
label: 消息平台适配器
description: "连接QQ使用的框架"
options:
- Mirai
- go-cqhttp
- yiri-mirai(Mirai)
- Nakuru(go-cqhttp)
- aiocqhttp(使用 OneBot 协议接入的)
- qq-botpy(QQ官方API)
validations:
required: false
- type: input
.gitignore (vendored) | 2

@@ -34,4 +34,4 @@ bard.json
res/instance_id.json
.DS_Store
/data
botpy.log
botpy.log*
@@ -5,6 +5,7 @@ COPY . .

RUN apt update \
&& apt install gcc -y \
&& python -m pip install -r requirements.txt
&& python -m pip install -r requirements.txt \
&& touch /.dockerenv

CMD [ "python", "main.py" ]
README.md | 36

@@ -2,33 +2,29 @@
<p align="center">
<img src="https://qchatgpt.rockchin.top/logo.png" alt="QChatGPT" width="180" />
</p>

<div align="center">

# QChatGPT

<a href="https://trendshift.io/repositories/6187" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6187" alt="RockChinQ%2FQChatGPT | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>

[](https://github.com/RockChinQ/QChatGPT/releases/latest)
<a href="https://hub.docker.com/repository/docker/rockchin/qchatgpt">
<img src="https://img.shields.io/docker/pulls/rockchin/qchatgpt?color=blue" alt="docker pull">
</a>

<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>


<br/>
<img src="https://img.shields.io/badge/python-3.9 | 3.10 | 3.11-blue.svg" alt="python">
<img src="https://img.shields.io/badge/python-3.10 | 3.11 | 3.12-blue.svg" alt="python">
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=66-aWvn8cbP4c1ut_1YYkvvGVeEtyTH8&authKey=pTaKBK5C%2B8dFzQ4XlENf6MHTCLaHnlKcCRx7c14EeVVlpX2nRSaS8lJm8YeM4mCU&noverify=0&group_code=195992197">
<img alt="Static Badge" src="https://img.shields.io/badge/%E5%AE%98%E6%96%B9%E7%BE%A4-195992197-purple">
</a>
<a href="http://qm.qq.com/cgi-bin/qm/qr?_wv=1027&k=nC80H57wmKPwRDLFeQrDDjVl81XuC21P&authKey=2wTUTfoQ5v%2BD4C5zfpuR%2BSPMDqdXgDXA%2FS2wHI1NxTfWIG%2B%2FqK08dgyjMMOzhXa9&noverify=0&group_code=248432104">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-248432104-purple">
</a>
<a href="https://www.bilibili.com/video/BV14h4y1w7TC">
<img alt="Static Badge" src="https://img.shields.io/badge/%E8%A7%86%E9%A2%91%E6%95%99%E7%A8%8B-208647">
</a>
<a href="https://www.bilibili.com/video/BV11h4y1y74H">
<img alt="Static Badge" src="https://img.shields.io/badge/Linux%E9%83%A8%E7%BD%B2%E8%A7%86%E9%A2%91-208647">
<a href="https://qm.qq.com/q/1yxEaIgXMA">
<img alt="Static Badge" src="https://img.shields.io/badge/%E7%A4%BE%E5%8C%BA%E7%BE%A4-619154800-purple">
</a>
<a href="https://codecov.io/gh/RockChinQ/QChatGPT" >
<img src="https://codecov.io/gh/RockChinQ/QChatGPT/graph/badge.svg?token=pjxYIL2kbC"/>
</a>

## 使用文档

@@ -43,7 +39,17 @@

<a href="https://github.com/RockChinQ/qcg-installer">安装器源码</a> |
<a href="https://github.com/RockChinQ/qcg-tester">测试工程源码</a> |
<a href="https://github.com/RockChinQ/qcg-center">遥测服务端源码</a> |
<a href="https://github.com/the-lazy-me/QChatGPT-Wiki">官方文档储存库</a>

<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-1211.png" width="500px"/>
<hr/>

<div align="center">
京东云4090单卡15C90G实例 <br/>
仅需1.89/小时,包月1225元起 <br/>
可选预装Stable Diffusion等应用,随用随停,计费透明,欢迎首选支持 <br/>
https://3.cn/1ZOi6-Gj
</div>

<img alt="回复效果(带有联网插件)" src="https://qchatgpt.rockchin.top/assets/image/QChatGPT-0516.png" width="500px"/>
</div>
main.py | 62

@@ -1,5 +1,6 @@
import asyncio

# QChatGPT 终端启动入口
# 在此层级解决依赖项检查。
# QChatGPT/main.py

asciiart = r"""
___ ___ _ _ ___ ___ _____
@@ -11,8 +12,61 @@ asciiart = r"""
📖文档地址: https://q.rkcn.top
"""

if __name__ == '__main__':

async def main_entry():
print(asciiart)

import sys

# 检查依赖
from pkg.core.bootutils import deps

missing_deps = await deps.check_deps()

if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
print("已自动安装缺失的依赖包,请重启程序。")
sys.exit(0)

# 检查配置文件
from pkg.core.bootutils import files

generated_files = await files.generate_files()

if generated_files:
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)

sys.exit(0)

from pkg.core import boot
asyncio.run(boot.main())
await boot.main()


if __name__ == '__main__':
import os

# 检查本目录是否有main.py,且包含QChatGPT字符串
invalid_pwd = False

if not os.path.exists('main.py'):
invalid_pwd = True
else:
with open('main.py', 'r', encoding='utf-8') as f:
content = f.read()
if "QChatGPT/main.py" not in content:
invalid_pwd = True
if invalid_pwd:
print("请在QChatGPT项目根目录下以命令形式运行此程序。")
input("按任意键退出...")
exit(0)

import asyncio

asyncio.run(main_entry())
@@ -34,6 +34,9 @@ class APIGroup(metaclass=abc.ABCMeta):
headers: dict = {},
**kwargs
):
"""
执行请求
"""
self._runtime_info['account_id'] = "-1"

url = self.prefix + path
@@ -9,8 +9,6 @@ from .groups import plugin
from ...core import app


BACKEND_URL = "https://api.qchatgpt.rockchin.top/api/v2"

class V2CenterAPI:
"""中央服务器 v2 API 交互类"""

@@ -23,7 +21,7 @@ class V2CenterAPI:
plugin: plugin.V2PluginDataAPI = None
"""插件 API 组"""

def __init__(self, ap: app.Application, basic_info: dict = None, runtime_info: dict = None):
def __init__(self, ap: app.Application, backend_url: str, basic_info: dict = None, runtime_info: dict = None):
"""初始化"""

logging.debug("basic_info: %s, runtime_info: %s", basic_info, runtime_info)
@@ -31,7 +29,7 @@ class V2CenterAPI:
apigroup.APIGroup._basic_info = basic_info
apigroup.APIGroup._runtime_info = runtime_info

self.main = main.V2MainDataAPI(BACKEND_URL, ap)
self.usage = usage.V2UsageDataAPI(BACKEND_URL, ap)
self.plugin = plugin.V2PluginDataAPI(BACKEND_URL, ap)
self.main = main.V2MainDataAPI(backend_url, ap)
self.usage = usage.V2UsageDataAPI(backend_url, ap)
self.plugin = plugin.V2PluginDataAPI(backend_url, ap)
@@ -1,3 +1,5 @@
# 实例 识别码 控制

import os
import uuid
import json
@@ -7,6 +7,7 @@ from ..provider import entities as llm_entities
from . import entities, operator, errors
from ..config import manager as cfg_mgr

# 引入所有算子以便注册
from .operators import func, plugin, default, reset, list as list_cmd, last, next, delc, resend, prompt, cmd, help, version, update

@@ -17,6 +18,9 @@ class CommandManager:
ap: app.Application

cmd_list: list[operator.CommandOperator]
"""
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
"""

def __init__(self, ap: app.Application):
self.ap = ap
@@ -60,7 +64,7 @@ class CommandManager:
"""

found = False
if len(context.crt_params) > 0:
if len(context.crt_params) > 0:  # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name \
or context.crt_params[0] in oper.alias) \
@@ -78,7 +82,7 @@ class CommandManager:
yield ret
break

if not found:
if not found:  # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(
error=errors.CommandNotFoundError(context.crt_params[0])
@@ -10,33 +10,67 @@ from . import errors, operator


class CommandReturn(pydantic.BaseModel):
"""命令返回值
"""

text: typing.Optional[str]
text: typing.Optional[str] = None
"""文本
"""

image: typing.Optional[mirai.Image]
image: typing.Optional[mirai.Image] = None
"""弃用"""

image_url: typing.Optional[str] = None
"""图片链接
"""

error: typing.Optional[errors.CommandError] = None
"""错误
"""

class Config:
arbitrary_types_allowed = True


class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文
"""

query: core_entities.Query
"""本次消息的请求对象"""

session: core_entities.Session
"""本次消息所属的会话对象"""

command_text: str
"""命令完整文本"""

command: str
"""命令名称"""

crt_command: str
"""当前命令

多级命令中crt_command为当前命令,command为根命令。
例如:!plugin on Webwlkr
处理到plugin时,command为plugin,crt_command为plugin
处理到on时,command为plugin,crt_command为on
"""

params: list[str]
"""命令参数

整个命令以空格分割后的参数列表
"""

crt_params: list[str]
"""当前命令参数

多级命令中crt_params为当前命令参数,params为根命令参数。
例如:!plugin on Webwlkr
处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
"""

privilege: int
"""发起人权限"""
@@ -8,17 +8,34 @@ from . import entities


preregistered_operators: list[typing.Type[CommandOperator]] = []
"""预注册命令算子列表。在初始化时,所有算子类会被注册到此列表中。"""


def operator_class(
name: str,
help: str,
help: str = "",
usage: str = None,
alias: list[str] = [],
privilege: int=1,  # 1为普通用户,2为管理员
parent_class: typing.Type[CommandOperator] = None
) -> typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]:
"""命令类装饰器

Args:
name (str): 名称
help (str, optional): 帮助信息. Defaults to "".
usage (str, optional): 使用说明. Defaults to None.
alias (list[str], optional): 别名. Defaults to [].
privilege (int, optional): 权限,1为普通用户可用,2为仅管理员可用. Defaults to 1.
parent_class (typing.Type[CommandOperator], optional): 父节点,若为None则为顶级命令. Defaults to None.

Returns:
typing.Callable[[typing.Type[CommandOperator]], typing.Type[CommandOperator]]: 装饰器
"""

def decorator(cls: typing.Type[CommandOperator]) -> typing.Type[CommandOperator]:
assert issubclass(cls, CommandOperator)

cls.name = name
cls.alias = alias
cls.help = help
@@ -34,7 +51,12 @@ def operator_class(


class CommandOperator(metaclass=abc.ABCMeta):
"""命令算子
"""命令算子抽象类

以下的参数均不需要在子类中设置,只需要在使用装饰器注册类时作为参数传递即可。
命令支持级联,即一个命令可以有多个子命令,子命令可以有子命令,以此类推。
处理命令时,若有子命令,会以当前参数列表的第一个参数去匹配子命令,若匹配成功,则转移到子命令中执行。
若没有匹配成功或没有子命令,则执行当前命令。
"""

ap: app.Application
@@ -43,7 +65,8 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""名称,搜索到时若符合则使用"""

path: str
"""路径,所有父节点的name的连接,用于定义命令权限"""
"""路径,所有父节点的name的连接,用于定义命令权限,由管理器在初始化时自动设置。
"""

alias: list[str]
"""同name"""
@@ -52,6 +75,7 @@ class CommandOperator(metaclass=abc.ABCMeta):
"""此节点的帮助信息"""

usage: str = None
"""用法"""

parent_class: typing.Union[typing.Type[CommandOperator], None] = None
"""父节点类。标记以供管理器在初始化时编织父子关系。"""
@@ -75,4 +99,15 @@ class CommandOperator(metaclass=abc.ABCMeta):
self,
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""实现此方法以执行命令

支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。

Args:
context (entities.ExecuteContext): 命令执行上下文

Yields:
entities.CommandReturn: 命令返回封装
"""
pass
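Based on the `operator_class` decorator and the `CommandOperator` abstract class in the hunk above, a new top-level command could plausibly be registered as in the sketch below. The module paths and the `echo` command itself are assumptions for illustration, not part of this changeset.

```python
# Minimal sketch of a custom command operator, assuming pkg.command.{entities,operator} paths.
import typing

from pkg.command import entities, operator


@operator.operator_class(
    name="echo",            # hypothetical command name
    help="原样返回参数",     # shown by the help command
    usage="!echo <text>",
    privilege=1,            # 1 = available to normal users
)
class EchoOperator(operator.CommandOperator):
    async def execute(
        self, context: entities.ExecuteContext
    ) -> typing.AsyncGenerator[entities.CommandReturn, None]:
        # Multiple yields are allowed; each one becomes a separate reply to the user.
        if not context.crt_params:
            yield entities.CommandReturn(text="nothing to echo")
        else:
            yield entities.CommandReturn(text=" ".join(context.crt_params))
```

A subcommand would pass `parent_class=EchoOperator` to the decorator instead; the manager wires parent and child nodes together at initialization.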
@@ -24,7 +24,7 @@ class DefaultOperator(operator.CommandOperator):

content = ""
for msg in prompt.messages:
content += f" {msg.role}: {msg.content}"
content += f" {msg.readable_str()}\n"

reply_str += f"名称: {prompt.name}\n内容: \n{content}\n\n"
@@ -45,18 +45,18 @@ class DefaultSetOperator(operator.CommandOperator):
context: entities.ExecuteContext
) -> typing.AsyncGenerator[entities.CommandReturn, None]:

if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]
if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供情景预设名称'))
else:
prompt_name = context.crt_params[0]

try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
try:
prompt = await self.ap.prompt_mgr.get_prompt_by_prefix(prompt_name)
if prompt is None:
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: 未找到情景预设 {}".format(prompt_name)))
else:
context.session.use_prompt_name = prompt.name
yield entities.CommandReturn(text=f"已设置当前会话默认情景预设为 {prompt_name}, !reset 后生效")
except Exception as e:
traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError("设置当前会话默认情景预设失败: "+str(e)))
@@ -30,7 +30,7 @@ class LastOperator(operator.CommandOperator):
context.session.using_conversation = context.session.conversations[index-1]
time_str = context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")

yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}")
yield entities.CommandReturn(text=f"已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}")
return
else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
@@ -42,7 +42,7 @@ class ListOperator(operator.CommandOperator):
using_conv_index = index

if index >= page * record_per_page and index < (page + 1) * record_per_page:
content += f"{index} {time_str}: {conv.messages[0].content if len(conv.messages) > 0 else '无内容'}\n"
content += f"{index} {time_str}: {conv.messages[0].readable_str() if len(conv.messages) > 0 else '无内容'}\n"
index += 1

if content == '':
@@ -51,6 +51,6 @@ class ListOperator(operator.CommandOperator):
if context.session.using_conversation is None:
content += "\n当前处于新会话"
else:
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].content if len(context.session.using_conversation.messages) > 0 else '无内容'}"
content += f"\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else '无内容'}"

yield entities.CommandReturn(text=f"第 {page + 1} 页 (时间倒序):\n{content}")
@@ -8,40 +8,52 @@ from .. import model as file_model
class JSONConfigFile(file_model.ConfigFile):
"""JSON配置文件"""

config_file_name: str = None
"""配置文件名"""

template_file_name: str = None
"""模板文件名"""

def __init__(self, config_file_name: str, template_file_name: str) -> None:
def __init__(
self, config_file_name: str, template_file_name: str = None, template_data: dict = None
) -> None:
self.config_file_name = config_file_name
self.template_file_name = template_file_name
self.template_data = template_data

def exists(self) -> bool:
return os.path.exists(self.config_file_name)

async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)
if self.template_file_name is not None:
shutil.copyfile(self.template_file_name, self.config_file_name)
elif self.template_data is not None:
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(self.template_data, f, indent=4, ensure_ascii=False)
else:
raise ValueError("template_file_name or template_data must be provided")

async def load(self) -> dict:
async def load(self, completion: bool=True) -> dict:

if not self.exists():
await self.create()

with open(self.config_file_name, 'r', encoding='utf-8') as f:
cfg = json.load(f)
if self.template_file_name is not None:
with open(self.template_file_name, "r", encoding="utf-8") as f:
self.template_data = json.load(f)

# 从模板文件中进行补全
with open(self.template_file_name, 'r', encoding='utf-8') as f:
template_cfg = json.load(f)
with open(self.config_file_name, "r", encoding="utf-8") as f:
try:
cfg = json.load(f)
except json.JSONDecodeError as e:
raise Exception(f"配置文件 {self.config_file_name} 语法错误: {e}")

for key in template_cfg:
if key not in cfg:
cfg[key] = template_cfg[key]
if completion:

for key in self.template_data:
if key not in cfg:
cfg[key] = self.template_data[key]

return cfg

async def save(self, cfg: dict):
with open(self.config_file_name, 'w', encoding='utf-8') as f:
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)

def save_sync(self, cfg: dict):
with open(self.config_file_name, "w", encoding="utf-8") as f:
json.dump(cfg, f, indent=4, ensure_ascii=False)
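Going by the new constructor and `load()` shown above, the `template_data` path can presumably be exercised directly as in this sketch. The import path and file name are assumptions for illustration only.

```python
# Sketch: creating a JSON config file from an in-memory template instead of a template file.
import asyncio

from pkg.config.impls.json import JSONConfigFile  # module path assumed from surrounding hunks


async def demo():
    cfg_file = JSONConfigFile(
        "data/config/example.json",                    # hypothetical target file
        template_data={"enable": True, "retries": 3},  # written out on first load if missing
    )
    data = await cfg_file.load()   # completion=True back-fills any missing keys from the template
    data["retries"] = 5
    await cfg_file.save(data)

asyncio.run(demo())
```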
@@ -25,7 +25,7 @@ class PythonModuleConfigFile(file_model.ConfigFile):
async def create(self):
shutil.copyfile(self.template_file_name, self.config_file_name)

async def load(self) -> dict:
async def load(self, completion: bool=True) -> dict:
module_name = os.path.splitext(os.path.basename(self.config_file_name))[0]
module = importlib.import_module(module_name)

@@ -43,20 +43,24 @@ class PythonModuleConfigFile(file_model.ConfigFile):
cfg[key] = getattr(module, key)

# 从模板模块文件中进行补全
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)
if completion:
module_name = os.path.splitext(os.path.basename(self.template_file_name))[0]
module = importlib.import_module(module_name)

for key in dir(module):
if key.startswith('__'):
continue
for key in dir(module):
if key.startswith('__'):
continue

if not isinstance(getattr(module, key), allowed_types):
continue
if not isinstance(getattr(module, key), allowed_types):
continue

if key not in cfg:
cfg[key] = getattr(module, key)
if key not in cfg:
cfg[key] = getattr(module, key)

return cfg

async def save(self, data: dict):
logging.warning('Python模块配置文件不支持保存')

def save_sync(self, data: dict):
logging.warning('Python模块配置文件不支持保存')
@@ -20,14 +20,17 @@ class ConfigManager:
self.file = cfg_file
self.data = {}

async def load_config(self):
self.data = await self.file.load()
async def load_config(self, completion: bool=True):
self.data = await self.file.load(completion=completion)

async def dump_config(self):
await self.file.save(self.data)

def dump_config_sync(self):
self.file.save_sync(self.data)

async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
async def load_python_module_config(config_name: str, template_name: str, completion: bool=True) -> ConfigManager:
"""加载Python模块配置文件"""
cfg_inst = pymodule.PythonModuleConfigFile(
config_name,
@@ -35,19 +38,20 @@ async def load_python_module_config(config_name: str, template_name: str) -> ConfigManager:
)

cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
await cfg_mgr.load_config(completion=completion)

return cfg_mgr


async def load_json_config(config_name: str, template_name: str) -> ConfigManager:
async def load_json_config(config_name: str, template_name: str=None, template_data: dict=None, completion: bool=True) -> ConfigManager:
"""加载JSON配置文件"""
cfg_inst = json_file.JSONConfigFile(
config_name,
template_name
template_name,
template_data
)

cfg_mgr = ConfigManager(cfg_inst)
await cfg_mgr.load_config()
await cfg_mgr.load_config(completion=completion)

return cfg_mgr
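With the extended signatures above, a config manager can now be built either from an in-memory template or without template completion. A sketch of both call styles, with placeholder paths and keys, follows; the module path is assumed from the surrounding hunks.

```python
# Sketch: the two new ways of loading a JSON config via load_json_config.
import asyncio

from pkg.config import manager as config_mgr  # module path assumed


async def demo():
    # 1) In-memory template instead of a template file; the file is generated on first load.
    meta = await config_mgr.load_json_config(
        "data/metadata/example.json",          # hypothetical file
        template_data={"threshold": 10},
    )

    # 2) completion=False skips back-filling missing keys from the template,
    #    which is how LoadConfigStage loads the main config files further below.
    pipeline = await config_mgr.load_json_config(
        "data/config/pipeline.json",
        "templates/pipeline.json",
        completion=False,
    )
    return meta.data, pipeline.data

asyncio.run(demo())
```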
pkg/config/migration.py (new file) | 47

@@ -0,0 +1,47 @@
from __future__ import annotations

import abc
import typing

from ..core import app


preregistered_migrations: list[typing.Type[Migration]] = []
"""当前阶段暂不支持扩展"""

def migration_class(name: str, number: int):
    """注册一个迁移
    """
    def decorator(cls: typing.Type[Migration]) -> typing.Type[Migration]:
        cls.name = name
        cls.number = number
        preregistered_migrations.append(cls)
        return cls

    return decorator


class Migration(abc.ABC):
    """一个版本的迁移
    """

    name: str

    number: int

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    @abc.abstractmethod
    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移
        """
        pass

    @abc.abstractmethod
    async def run(self):
        """执行迁移
        """
        pass
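The migrations that follow (m001 through m009) all use this registry in the same way: a decorator call plus `need_migrate`/`run`. A hypothetical follow-up migration might look like the sketch below; the name, number 10, and config key are invented for illustration and are not part of this changeset.

```python
# Sketch only: how a future migration could plug into the registry defined above.
from __future__ import annotations

from .. import migration


@migration.migration_class("example-flag-completion", 10)  # hypothetical name and number
class ExampleFlagCompletionMigration(migration.Migration):
    """示例迁移(仅作说明,非本次变更内容)"""

    async def need_migrate(self) -> bool:
        # Run only while the (made-up) key is still missing from provider.json.
        return 'example-flag' not in self.ap.provider_cfg.data

    async def run(self):
        self.ap.provider_cfg.data['example-flag'] = False
        await self.ap.provider_cfg.dump_config()
```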
pkg/config/migrations/m001_sensitive_word_migration.py (new file) | 26

@@ -0,0 +1,26 @@
from __future__ import annotations

import os
import sys

from .. import migration


@migration.migration_class("sensitive-word-migration", 1)
class SensitiveWordMigration(migration.Migration):
    """敏感词迁移
    """

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移
        """
        return os.path.exists("data/config/sensitive-words.json") and not os.path.exists("data/metadata/sensitive-words.json")

    async def run(self):
        """执行迁移
        """
        # 移动文件
        os.rename("data/config/sensitive-words.json", "data/metadata/sensitive-words.json")

        # 重新加载配置
        await self.ap.sensitive_meta.load_config()
pkg/config/migrations/m002_openai_config_migration.py (new file) | 47

@@ -0,0 +1,47 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("openai-config-migration", 2)
class OpenAIConfigMigration(migration.Migration):
    """OpenAI配置迁移
    """

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移
        """
        return 'openai-config' in self.ap.provider_cfg.data

    async def run(self):
        """执行迁移
        """
        old_openai_config = self.ap.provider_cfg.data['openai-config'].copy()

        if 'keys' not in self.ap.provider_cfg.data:
            self.ap.provider_cfg.data['keys'] = {}

        if 'openai' not in self.ap.provider_cfg.data['keys']:
            self.ap.provider_cfg.data['keys']['openai'] = []

        self.ap.provider_cfg.data['keys']['openai'] = old_openai_config['api-keys']

        self.ap.provider_cfg.data['model'] = old_openai_config['chat-completions-params']['model']

        del old_openai_config['chat-completions-params']['model']

        if 'requester' not in self.ap.provider_cfg.data:
            self.ap.provider_cfg.data['requester'] = {}

        if 'openai-chat-completions' not in self.ap.provider_cfg.data['requester']:
            self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {}

        self.ap.provider_cfg.data['requester']['openai-chat-completions'] = {
            'base-url': old_openai_config['base_url'],
            'args': old_openai_config['chat-completions-params'],
            'timeout': old_openai_config['request-timeout'],
        }

        del self.ap.provider_cfg.data['openai-config']

        await self.ap.provider_cfg.dump_config()
@@ -0,0 +1,32 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("anthropic-requester-config-completion", 3)
class AnthropicRequesterConfigCompletionMigration(migration.Migration):
    """OpenAI配置迁移
    """

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移
        """
        return 'anthropic-messages' not in self.ap.provider_cfg.data['requester'] \
            or 'anthropic' not in self.ap.provider_cfg.data['keys']

    async def run(self):
        """执行迁移
        """
        if 'anthropic-messages' not in self.ap.provider_cfg.data['requester']:
            self.ap.provider_cfg.data['requester']['anthropic-messages'] = {
                'base-url': 'https://api.anthropic.com',
                'args': {
                    'max_tokens': 1024
                },
                'timeout': 120,
            }

        if 'anthropic' not in self.ap.provider_cfg.data['keys']:
            self.ap.provider_cfg.data['keys']['anthropic'] = []

        await self.ap.provider_cfg.dump_config()
pkg/config/migrations/m004_moonshot_cfg_completion.py (new file) | 30

@@ -0,0 +1,30 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("moonshot-config-completion", 4)
class MoonshotConfigCompletionMigration(migration.Migration):
    """OpenAI配置迁移
    """

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移
        """
        return 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester'] \
            or 'moonshot' not in self.ap.provider_cfg.data['keys']

    async def run(self):
        """执行迁移
        """
        if 'moonshot-chat-completions' not in self.ap.provider_cfg.data['requester']:
            self.ap.provider_cfg.data['requester']['moonshot-chat-completions'] = {
                'base-url': 'https://api.moonshot.cn/v1',
                'args': {},
                'timeout': 120,
            }

        if 'moonshot' not in self.ap.provider_cfg.data['keys']:
            self.ap.provider_cfg.data['keys']['moonshot'] = []

        await self.ap.provider_cfg.dump_config()
pkg/config/migrations/m005_deepseek_cfg_completion.py (new file) | 30

@@ -0,0 +1,30 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("deepseek-config-completion", 5)
class DeepseekConfigCompletionMigration(migration.Migration):
    """OpenAI配置迁移
    """

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移
        """
        return 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester'] \
            or 'deepseek' not in self.ap.provider_cfg.data['keys']

    async def run(self):
        """执行迁移
        """
        if 'deepseek-chat-completions' not in self.ap.provider_cfg.data['requester']:
            self.ap.provider_cfg.data['requester']['deepseek-chat-completions'] = {
                'base-url': 'https://api.deepseek.com',
                'args': {},
                'timeout': 120,
            }

        if 'deepseek' not in self.ap.provider_cfg.data['keys']:
            self.ap.provider_cfg.data['keys']['deepseek'] = []

        await self.ap.provider_cfg.dump_config()
pkg/config/migrations/m006_vision_config.py (new file) | 19

@@ -0,0 +1,19 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("vision-config", 6)
class VisionConfigMigration(migration.Migration):
    """迁移"""

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移"""
        return "enable-vision" not in self.ap.provider_cfg.data

    async def run(self):
        """执行迁移"""
        if "enable-vision" not in self.ap.provider_cfg.data:
            self.ap.provider_cfg.data["enable-vision"] = False

        await self.ap.provider_cfg.dump_config()
pkg/config/migrations/m007_qcg_center_url.py (new file) | 20

@@ -0,0 +1,20 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("qcg-center-url-config", 7)
class QCGCenterURLConfigMigration(migration.Migration):
    """迁移"""

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移"""
        return "qcg-center-url" not in self.ap.system_cfg.data

    async def run(self):
        """执行迁移"""

        if "qcg-center-url" not in self.ap.system_cfg.data:
            self.ap.system_cfg.data["qcg-center-url"] = "https://api.qchatgpt.rockchin.top/api/v2"

        await self.ap.system_cfg.dump_config()
pkg/config/migrations/m008_ad_fixwin_config_migrate.py (new file) | 29

@@ -0,0 +1,29 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("ad-fixwin-cfg-migration", 8)
class AdFixwinConfigMigration(migration.Migration):
    """迁移"""

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移"""
        return isinstance(
            self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]["default"],
            int
        )

    async def run(self):
        """执行迁移"""

        for session_name in self.ap.pipeline_cfg.data["rate-limit"]["fixwin"]:

            temp_dict = {
                "window-size": 60,
                "limit": self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name]
            }

            self.ap.pipeline_cfg.data["rate-limit"]["fixwin"][session_name] = temp_dict

        await self.ap.pipeline_cfg.dump_config()
pkg/config/migrations/m009_msg_truncator_cfg.py (new file) | 24

@@ -0,0 +1,24 @@
from __future__ import annotations

from .. import migration


@migration.migration_class("msg-truncator-cfg-migration", 9)
class MsgTruncatorConfigMigration(migration.Migration):
    """迁移"""

    async def need_migrate(self) -> bool:
        """判断当前环境是否需要运行此迁移"""
        return 'msg-truncate' not in self.ap.pipeline_cfg.data

    async def run(self):
        """执行迁移"""

        self.ap.pipeline_cfg.data['msg-truncate'] = {
            'method': 'round',
            'round': {
                'max-round': 10
            }
        }

        await self.ap.pipeline_cfg.dump_config()
@@ -10,6 +10,9 @@ class ConfigFile(metaclass=abc.ABCMeta):
template_file_name: str = None
"""模板文件名"""

template_data: dict = None
"""模板数据"""

@abc.abstractmethod
def exists(self) -> bool:
pass
@@ -19,9 +22,13 @@ class ConfigFile(metaclass=abc.ABCMeta):
pass

@abc.abstractmethod
async def load(self) -> dict:
async def load(self, completion: bool=True) -> dict:
pass

@abc.abstractmethod
async def save(self, data: dict):
pass

@abc.abstractmethod
def save_sync(self, data: dict):
pass
@@ -4,24 +4,24 @@ import logging
import asyncio
import traceback

import aioconsole

from ..platform import manager as im_mgr
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr
from ..audit.center import v2 as center_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from . import pool, controller
from ..pipeline import stagemgr
from ..pipeline import pool
from ..pipeline import controller, stagemgr
from ..utils import version as version_mgr, proxy as proxy_mgr


class Application:
im_mgr: im_mgr.PlatformManager = None
"""运行时应用对象和上下文"""

platform_mgr: im_mgr.PlatformManager = None

cmd_mgr: cmdmgr.CommandManager = None

@@ -33,6 +33,8 @@ class Application:

tool_mgr: llm_tool_mgr.ToolManager = None

# ======= 配置管理器 =======

command_cfg: config_mgr.ConfigManager = None

pipeline_cfg: config_mgr.ConfigManager = None
@@ -43,6 +45,18 @@ class Application:

system_cfg: config_mgr.ConfigManager = None

# ======= 元数据配置管理器 =======

sensitive_meta: config_mgr.ConfigManager = None

adapter_qq_botpy_meta: config_mgr.ConfigManager = None

plugin_setting_meta: config_mgr.ConfigManager = None

llm_models_meta: config_mgr.ConfigManager = None

# =========================

ctr_mgr: center_mgr.V2CenterAPI = None

plugin_mgr: plugin_mgr.PluginManager = None

@@ -66,27 +80,18 @@ class Application:
pass

async def run(self):
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()

tasks = []

try:

tasks = [
asyncio.create_task(self.im_mgr.run()),
asyncio.create_task(self.platform_mgr.run()),
asyncio.create_task(self.ctrl.run())
]

# async def interrupt(tasks):
# await asyncio.sleep(1.5)
# while await aioconsole.ainput("使用 ctrl+c 或 'exit' 退出程序 > ") != 'exit':
# pass
# for task in tasks:
# task.cancel()

# await interrupt(tasks)
# 挂信号处理

import signal
pkg/core/boot.py | 148

@@ -1,143 +1,36 @@
from __future__ import print_function

import os
import sys

from .bootutils import files
from .bootutils import deps
from .bootutils import log
from .bootutils import config
import traceback

from . import app
from . import pool
from . import controller
from ..pipeline import stagemgr
from ..audit import identifier
from ..provider.session import sessionmgr as llm_session_mgr
from ..provider.requester import modelmgr as llm_model_mgr
from ..provider.sysprompt import sysprompt as llm_prompt_mgr
from ..provider.tools import toolmgr as llm_tool_mgr
from ..platform import manager as im_mgr
from ..command import cmdmgr
from ..plugin import manager as plugin_mgr
from ..audit.center import v2 as center_v2
from ..utils import version, proxy, announce
from . import stage

use_override = False
# 引入启动阶段实现以便注册
from .stages import load_config, setup_logger, build_app, migrate


stage_order = [
"LoadConfigStage",
"MigrationStage",
"SetupLoggerStage",
"BuildAppStage"
]


async def make_app() -> app.Application:
global use_override

generated_files = await files.generate_files()

if generated_files:
print("以下文件不存在,已自动生成,请按需修改配置文件后重启:")
for file in generated_files:
print("-", file)

sys.exit(0)

missing_deps = await deps.check_deps()

if missing_deps:
print("以下依赖包未安装,将自动安装,请完成后重启程序:")
for dep in missing_deps:
print("-", dep)
await deps.install_deps(missing_deps)
sys.exit(0)

qcg_logger = await log.init_logging()

# 生成标识符
identifier.init()

# ========== 加载配置文件 ==========

command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json")
pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json")
platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json")
provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json")
system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json")

# ========== 构建应用实例 ==========
ap = app.Application()
ap.logger = qcg_logger

ap.command_cfg = command_cfg
ap.pipeline_cfg = pipeline_cfg
ap.platform_cfg = platform_cfg
ap.provider_cfg = provider_cfg
ap.system_cfg = system_cfg
# 执行启动阶段
for stage_name in stage_order:
stage_cls = stage.preregistered_stages[stage_name]
stage_inst = stage_cls()

proxy_mgr = proxy.ProxyManager(ap)
await proxy_mgr.initialize()
ap.proxy_mgr = proxy_mgr

ver_mgr = version.VersionManager(ap)
await ver_mgr.initialize()
ap.ver_mgr = ver_mgr

center_v2_api = center_v2.V2CenterAPI(
ap,
basic_info={
"host_id": identifier.identifier["host_id"],
"instance_id": identifier.identifier["instance_id"],
"semantic_version": ver_mgr.get_current_version(),
"platform": sys.platform,
},
runtime_info={
"admin_id": "{}".format(system_cfg.data["admin-sessions"]),
"msg_source": str([
adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
for adapter_cfg in platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
]),
},
)
ap.ctr_mgr = center_v2_api

# 发送公告
ann_mgr = announce.AnnouncementManager(ap)
await ann_mgr.show_announcements()

ap.query_pool = pool.QueryPool()

await ap.ver_mgr.show_version_update()

plugin_mgr_inst = plugin_mgr.PluginManager(ap)
await plugin_mgr_inst.initialize()
ap.plugin_mgr = plugin_mgr_inst

cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize()
ap.cmd_mgr = cmd_mgr_inst

llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
await llm_model_mgr_inst.initialize()
ap.model_mgr = llm_model_mgr_inst

llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
await llm_session_mgr_inst.initialize()
ap.sess_mgr = llm_session_mgr_inst

llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
await llm_prompt_mgr_inst.initialize()
ap.prompt_mgr = llm_prompt_mgr_inst

llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst

im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.im_mgr = im_mgr_inst

stage_mgr = stagemgr.StageManager(ap)
await stage_mgr.initialize()
ap.stage_mgr = stage_mgr

ctrl = controller.Controller(ap)
ap.ctrl = ctrl
await stage_inst.run(ap)

await ap.initialize()

@@ -145,5 +38,8 @@ async def make_app() -> app.Application:


async def main():
app_inst = await make_app()
await app_inst.run()
try:
app_inst = await make_app()
await app_inst.run()
except Exception as e:
traceback.print_exc()
@@ -8,16 +8,3 @@ from ...config.impls import pymodule

load_python_module_config = config_mgr.load_python_module_config
load_json_config = config_mgr.load_json_config


async def override_config_manager(cfg_mgr: config_mgr.ConfigManager) -> list[str]:
override_json = json.load(open("override.json", "r", encoding="utf-8"))
overrided = []

config = cfg_mgr.data
for key in override_json:
if key in config:
config[key] = override_json[key]
overrided.append(key)

return overrided
@@ -3,14 +3,18 @@ import pip

required_deps = {
"requests": "requests",
"openai": "openai",
"anthropic": "anthropic",
"colorlog": "colorlog",
"mirai": "yiri-mirai-rc",
"aiocqhttp": "aiocqhttp",
"botpy": "qq-botpy",
"PIL": "pillow",
"nakuru": "nakuru-project-idk",
"CallingGPT": "CallingGPT",
"tiktoken": "tiktoken",
"yaml": "pyyaml",
"aiohttp": "aiohttp",
"psutil": "psutil",
"async_lru": "async-lru",
}
@@ -13,13 +13,13 @@ required_files = {
"data/config/platform.json": "templates/platform.json",
"data/config/provider.json": "templates/provider.json",
"data/config/system.json": "templates/system.json",
"data/config/sensitive-words.json": "templates/sensitive-words.json",
"data/scenario/default.json": "templates/scenario-template.json",
}

required_paths = [
"temp",
"data",
"data/metadata",
"data/prompts",
"data/scenario",
"data/logs",
@@ -16,6 +16,10 @@ log_colors_config = {


async def init_logging() -> logging.Logger:
# 删除所有现有的logger
for handler in logging.root.handlers[:]:
logging.root.removeHandler(handler)

level = logging.INFO

if "DEBUG" in os.environ and os.environ["DEBUG"] in ["true", "1"]:
@@ -46,7 +50,7 @@ async def init_logging() -> logging.Logger:

qcg_logger.debug("日志初始化完成,日志级别:%s" % level)
logging.basicConfig(
level=logging.INFO,  # 设置日志输出格式
level=logging.CRITICAL,  # 设置日志输出格式
format="[DEPR][%(asctime)s.%(msecs)03d] %(pathname)s (%(lineno)d) - [%(levelname)s] :\n%(message)s",
# 日志输出的格式
# -8表示占位符,让输出左对齐,输出长度都为8位
@@ -9,13 +9,14 @@ import pydantic
import mirai

from ..provider import entities as llm_entities
from ..provider.requester import entities
from ..provider.modelmgr import entities
from ..provider.sysprompt import entities as sysprompt_entities
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter


class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""

PERSON = 'person'
"""私聊"""
@@ -31,45 +32,45 @@ class Query(pydantic.BaseModel):
"""请求ID,添加进请求池时生成"""

launcher_type: LauncherTypes
"""会话类型,platform设置"""
"""会话类型,platform处理阶段设置"""

launcher_id: int
"""会话ID,platform设置"""
"""会话ID,platform处理阶段设置"""

sender_id: int
"""发送者ID,platform设置"""
"""发送者ID,platform处理阶段设置"""

message_event: mirai.MessageEvent
"""事件,platform收到的事件"""
"""事件,platform收到的原始事件"""

message_chain: mirai.MessageChain
"""消息链,platform收到的消息链"""
"""消息链,platform收到的原始消息链"""

adapter: msadapter.MessageSourceAdapter
"""适配器对象"""
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""

session: typing.Optional[Session] = None
"""会话对象,由前置处理器设置"""
"""会话对象,由前置处理器阶段设置"""

messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器设置"""
"""历史消息列表,由前置处理器阶段设置"""

prompt: typing.Optional[sysprompt_entities.Prompt] = None
"""情景预设内容,由前置处理器设置"""
"""情景预设内容,由前置处理器阶段设置"""

user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器设置"""
"""此次请求的用户消息对象,由前置处理器阶段设置"""

use_model: typing.Optional[entities.LLMModelInfo] = None
"""使用的模型,由前置处理器设置"""
"""使用的模型,由前置处理器阶段设置"""

use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器设置"""
"""使用的函数,由前置处理器阶段设置"""

resp_messages: typing.Optional[list[llm_entities.Message]] = []
"""由provider生成的回复消息对象列表"""
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
"""由Process阶段生成的回复消息对象列表"""

resp_message_chain: typing.Optional[mirai.MessageChain] = None
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None
"""回复消息链,从resp_messages包装而得"""

class Config:
@@ -77,7 +78,7 @@ class Query(pydantic.BaseModel):


class Conversation(pydantic.BaseModel):
"""对话"""
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""

prompt: sysprompt_entities.Prompt

@@ -93,7 +94,7 @@ class Conversation(pydantic.BaseModel):


class Session(pydantic.BaseModel):
"""会话"""
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes

launcher_id: int
@@ -111,6 +112,7 @@ class Session(pydantic.BaseModel):
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)

semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""

class Config:
arbitrary_types_allowed = True
pkg/core/stage.py (new file) | 34

@@ -0,0 +1,34 @@
from __future__ import annotations

import abc
import typing

from . import app


preregistered_stages: dict[str, typing.Type[BootingStage]] = {}
"""预注册的请求处理阶段。在初始化时,所有请求处理阶段类会被注册到此字典中。

当前阶段暂不支持扩展
"""

def stage_class(
    name: str
):
    def decorator(cls: typing.Type[BootingStage]) -> typing.Type[BootingStage]:
        preregistered_stages[name] = cls
        return cls

    return decorator


class BootingStage(abc.ABC):
    """启动阶段
    """
    name: str = None

    @abc.abstractmethod
    async def run(self, ap: app.Application):
        """启动
        """
        pass
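The four stages added below (LoadConfigStage, MigrationStage, SetupLoggerStage, BuildAppStage) all register through this decorator. A sketch of how a further boot stage could be added follows; the stage name and the attribute it sets are invented, and note that boot.py only executes the names listed in its `stage_order`, so a real new stage would also need to be appended there.

```python
# Sketch only: registering an additional booting stage with the decorator defined above.
from __future__ import annotations

from .. import stage, app


@stage.stage_class("ExampleStage")          # hypothetical stage name
class ExampleStage(stage.BootingStage):
    """示例启动阶段(仅作说明)"""

    async def run(self, ap: app.Application):
        # A real stage would build one component and attach it to the Application object,
        # the same way BuildAppStage below wires up its managers.
        ap.example_component = object()
```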
pkg/core/stages/build_app.py (new file) | 96

@@ -0,0 +1,96 @@
from __future__ import annotations

import sys

from .. import stage, app
from ...utils import version, proxy, announce, platform
from ...audit.center import v2 as center_v2
from ...audit import identifier
from ...pipeline import pool, controller, stagemgr
from ...plugin import manager as plugin_mgr
from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.sysprompt import sysprompt as llm_prompt_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...platform import manager as im_mgr

@stage.stage_class("BuildAppStage")
class BuildAppStage(stage.BootingStage):
    """构建应用阶段
    """

    async def run(self, ap: app.Application):
        """构建app对象的各个组件对象并初始化
        """

        proxy_mgr = proxy.ProxyManager(ap)
        await proxy_mgr.initialize()
        ap.proxy_mgr = proxy_mgr

        ver_mgr = version.VersionManager(ap)
        await ver_mgr.initialize()
        ap.ver_mgr = ver_mgr

        center_v2_api = center_v2.V2CenterAPI(
            ap,
            backend_url=ap.system_cfg.data["qcg-center-url"],
            basic_info={
                "host_id": identifier.identifier["host_id"],
                "instance_id": identifier.identifier["instance_id"],
                "semantic_version": ver_mgr.get_current_version(),
                "platform": platform.get_platform(),
            },
            runtime_info={
                "admin_id": "{}".format(ap.system_cfg.data["admin-sessions"]),
                "msg_source": str([
                    adapter_cfg['adapter'] if 'adapter' in adapter_cfg else 'unknown'
                    for adapter_cfg in ap.platform_cfg.data['platform-adapters'] if adapter_cfg['enable']
                ]),
            },
        )
        ap.ctr_mgr = center_v2_api

        # 发送公告
        ann_mgr = announce.AnnouncementManager(ap)
        await ann_mgr.show_announcements()

        ap.query_pool = pool.QueryPool()

        await ap.ver_mgr.show_version_update()

        plugin_mgr_inst = plugin_mgr.PluginManager(ap)
        await plugin_mgr_inst.initialize()
        ap.plugin_mgr = plugin_mgr_inst
        await plugin_mgr_inst.load_plugins()

        cmd_mgr_inst = cmdmgr.CommandManager(ap)
        await cmd_mgr_inst.initialize()
        ap.cmd_mgr = cmd_mgr_inst

        llm_model_mgr_inst = llm_model_mgr.ModelManager(ap)
        await llm_model_mgr_inst.initialize()
        ap.model_mgr = llm_model_mgr_inst

        llm_session_mgr_inst = llm_session_mgr.SessionManager(ap)
        await llm_session_mgr_inst.initialize()
        ap.sess_mgr = llm_session_mgr_inst

        llm_prompt_mgr_inst = llm_prompt_mgr.PromptManager(ap)
        await llm_prompt_mgr_inst.initialize()
        ap.prompt_mgr = llm_prompt_mgr_inst

        llm_tool_mgr_inst = llm_tool_mgr.ToolManager(ap)
        await llm_tool_mgr_inst.initialize()
        ap.tool_mgr = llm_tool_mgr_inst

        im_mgr_inst = im_mgr.PlatformManager(ap=ap)
        await im_mgr_inst.initialize()
        ap.platform_mgr = im_mgr_inst

        stage_mgr = stagemgr.StageManager(ap)
        await stage_mgr.initialize()
        ap.stage_mgr = stage_mgr

        ctrl = controller.Controller(ap)
        ap.ctrl = ctrl
pkg/core/stages/load_config.py (new file) | 31

@@ -0,0 +1,31 @@
from __future__ import annotations

from .. import stage, app
from ..bootutils import config


@stage.stage_class("LoadConfigStage")
class LoadConfigStage(stage.BootingStage):
    """加载配置文件阶段
    """

    async def run(self, ap: app.Application):
        """启动
        """
        ap.command_cfg = await config.load_json_config("data/config/command.json", "templates/command.json", completion=False)
        ap.pipeline_cfg = await config.load_json_config("data/config/pipeline.json", "templates/pipeline.json", completion=False)
        ap.platform_cfg = await config.load_json_config("data/config/platform.json", "templates/platform.json", completion=False)
        ap.provider_cfg = await config.load_json_config("data/config/provider.json", "templates/provider.json", completion=False)
        ap.system_cfg = await config.load_json_config("data/config/system.json", "templates/system.json", completion=False)

        ap.plugin_setting_meta = await config.load_json_config("plugins/plugins.json", "templates/plugin-settings.json")
        await ap.plugin_setting_meta.dump_config()

        ap.sensitive_meta = await config.load_json_config("data/metadata/sensitive-words.json", "templates/metadata/sensitive-words.json")
        await ap.sensitive_meta.dump_config()

        ap.adapter_qq_botpy_meta = await config.load_json_config("data/metadata/adapter-qq-botpy.json", "templates/metadata/adapter-qq-botpy.json")
        await ap.adapter_qq_botpy_meta.dump_config()

        ap.llm_models_meta = await config.load_json_config("data/metadata/llm-models.json", "templates/metadata/llm-models.json")
        await ap.llm_models_meta.dump_config()
29  pkg/core/stages/migrate.py  Normal file
@@ -0,0 +1,29 @@
from __future__ import annotations

import importlib

from .. import stage, app
from ...config import migration
from ...config.migrations import m001_sensitive_word_migration, m002_openai_config_migration, m003_anthropic_requester_cfg_completion, m004_moonshot_cfg_completion
from ...config.migrations import m005_deepseek_cfg_completion, m006_vision_config, m007_qcg_center_url, m008_ad_fixwin_config_migrate, m009_msg_truncator_cfg


@stage.stage_class("MigrationStage")
class MigrationStage(stage.BootingStage):
    """Migration stage.
    """

    async def run(self, ap: app.Application):
        """Run.
        """

        migrations = migration.preregistered_migrations

        # Sort by migration number
        migrations.sort(key=lambda x: x.number)

        for migration_cls in migrations:
            migration_instance = migration_cls(ap)

            if await migration_instance.need_migrate():
                await migration_instance.run()
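For context, each of the m001–m009 modules imported above contributes one migration to migration.preregistered_migrations. A minimal hedged sketch of what a new entry would look like under the same contract; the base-class name migration.Migration, the self.ap attribute and the config key are assumptions made for illustration, and the real modules may register through a decorator rather than an explicit append:

from ...config import migration


class ExampleCfgCompletionMigration(migration.Migration):   # hypothetical migration, not part of this commit
    number = 10   # must be unique; MigrationStage sorts ascending by this attribute

    async def need_migrate(self) -> bool:
        # run only while the illustrative key is still missing
        return 'example-key' not in self.ap.provider_cfg.data

    async def run(self):
        self.ap.provider_cfg.data['example-key'] = 'default-value'
        await self.ap.provider_cfg.dump_config()


migration.preregistered_migrations.append(ExampleCfgCompletionMigration)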
15  pkg/core/stages/setup_logger.py  Normal file
@@ -0,0 +1,15 @@
from __future__ import annotations

from .. import stage, app
from ..bootutils import log


@stage.stage_class("SetupLoggerStage")
class SetupLoggerStage(stage.BootingStage):
    """Stage that sets up the logger.
    """

    async def run(self, ap: app.Application):
        """Run.
        """
        ap.logger = await log.init_logging()
@@ -8,6 +8,10 @@ from ...config import manager as cfg_mgr
|
||||
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
class BanSessionCheckStage(stage.PipelineStage):
|
||||
"""访问控制处理阶段
|
||||
|
||||
仅检查query中群号或个人号是否在访问控制列表中。
|
||||
"""
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
@@ -24,22 +28,24 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
|
||||
sess_list = self.ap.pipeline_cfg.data['access-control'][mode]
|
||||
|
||||
if (query.launcher_type == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type == 'person' and 'person_*' in sess_list):
|
||||
if (query.launcher_type.value == 'group' and 'group_*' in sess_list) \
|
||||
or (query.launcher_type.value == 'person' and 'person_*' in sess_list):
|
||||
found = True
|
||||
else:
|
||||
for sess in sess_list:
|
||||
if sess == f"{query.launcher_type}_{query.launcher_id}":
|
||||
if sess == f"{query.launcher_type.value}_{query.launcher_id}":
|
||||
found = True
|
||||
break
|
||||
|
||||
result = False
|
||||
ctn = False
|
||||
|
||||
if mode == 'blacklist':
|
||||
result = found
|
||||
if mode == 'whitelist':
|
||||
ctn = found
|
||||
else:
|
||||
ctn = not found
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE if not result else entities.ResultType.INTERRUPT,
|
||||
result_type=entities.ResultType.CONTINUE if ctn else entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
debug_notice=f'根据访问控制忽略消息: {query.launcher_type}_{query.launcher_id}' if result else ''
|
||||
console_notice=f'根据访问控制忽略消息: {query.launcher_type.value}_{query.launcher_id}' if not ctn else ''
|
||||
)
|
||||
|
||||
@@ -7,28 +7,50 @@ from ...core import app
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...config import manager as cfg_mgr
|
||||
from . import filter, entities as filter_entities
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from .filters import cntignore, banwords, baiduexamine
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
@stage.stage_class('PostContentFilterStage')
|
||||
@stage.stage_class('PreContentFilterStage')
|
||||
class ContentFilterStage(stage.PipelineStage):
|
||||
"""内容过滤阶段
|
||||
|
||||
filter_chain: list[filter.ContentFilter]
|
||||
前置:
|
||||
检查消息是否符合规则,不符合则拦截。
|
||||
改写:
|
||||
message_chain
|
||||
|
||||
后置:
|
||||
检查AI回复消息是否符合规则,可能进行改写,不符合则拦截。
|
||||
改写:
|
||||
query.resp_messages
|
||||
"""
|
||||
|
||||
filter_chain: list[filter_model.ContentFilter]
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.filter_chain = []
|
||||
super().__init__(ap)
|
||||
|
||||
async def initialize(self):
|
||||
self.filter_chain.append(cntignore.ContentIgnore(self.ap))
|
||||
|
||||
filters_required = [
|
||||
"content-ignore",
|
||||
]
|
||||
|
||||
if self.ap.pipeline_cfg.data['check-sensitive-words']:
|
||||
self.filter_chain.append(banwords.BanWordFilter(self.ap))
|
||||
filters_required.append("ban-word-filter")
|
||||
|
||||
if self.ap.pipeline_cfg.data['baidu-cloud-examine']['enable']:
|
||||
self.filter_chain.append(baiduexamine.BaiduCloudExamine(self.ap))
|
||||
filters_required.append("baidu-cloud-examine")
|
||||
|
||||
for filter in filter_model.preregistered_filters:
|
||||
if filter.name in filters_required:
|
||||
self.filter_chain.append(
|
||||
filter(self.ap)
|
||||
)
|
||||
|
||||
for filter in self.filter_chain:
|
||||
await filter.initialize()
|
||||
@@ -120,14 +142,37 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
"""处理
|
||||
"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
|
||||
contain_non_text = False
|
||||
|
||||
for me in query.message_chain:
|
||||
if not isinstance(me, mirai.Plain):
|
||||
contain_non_text = True
|
||||
break
|
||||
|
||||
if contain_non_text:
|
||||
self.ap.logger.debug(f"消息中包含非文本消息,跳过内容过滤器检查。")
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
return await self._pre_process(
|
||||
str(query.message_chain).strip(),
|
||||
query
|
||||
)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
return await self._post_process(
|
||||
query.resp_messages[-1].content,
|
||||
query
|
||||
)
|
||||
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
|
||||
return await self._post_process(
|
||||
query.resp_messages[-1].content,
|
||||
query
|
||||
)
|
||||
else:
|
||||
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
raise ValueError(f'未知的 stage_inst_name: {stage_inst_name}')
|
||||
|
||||
@@ -4,6 +4,8 @@ import enum
|
||||
|
||||
import pydantic
|
||||
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
"""结果等级"""
|
||||
@@ -31,15 +33,24 @@ class EnableStage(enum.Enum):
|
||||
|
||||
class FilterResult(pydantic.BaseModel):
|
||||
level: ResultLevel
|
||||
"""结果等级
|
||||
|
||||
对于前置处理阶段,只要有任意一个返回 非PASS 的内容过滤器结果,就会中断处理。
|
||||
对于后置处理阶段,当且内容过滤器返回 BLOCK 时,会中断处理。
|
||||
"""
|
||||
|
||||
replacement: str
|
||||
"""替换后的消息"""
|
||||
"""替换后的文本消息
|
||||
|
||||
内容过滤器可以进行一些遮掩处理,然后把遮掩后的消息返回。
|
||||
若没有修改内容,也需要返回原消息。
|
||||
"""
|
||||
|
||||
user_notice: str
|
||||
"""不通过时,用户提示消息"""
|
||||
"""不通过时,若此值不为空,将对用户提示消息"""
|
||||
|
||||
console_notice: str
|
||||
"""不通过时,控制台提示消息"""
|
||||
"""不通过时,若此值不为空,将在控制台提示消息"""
|
||||
|
||||
|
||||
class ManagerResultLevel(enum.Enum):
|
||||
|
||||
@@ -1,12 +1,43 @@
|
||||
# 内容过滤器的抽象类
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
from . import entities
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
|
||||
def filter_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]:
|
||||
"""内容过滤器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 过滤器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[ContentFilter]], typing.Type[ContentFilter]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[ContentFilter]) -> typing.Type[ContentFilter]:
|
||||
assert issubclass(cls, ContentFilter)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_filters.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ContentFilter(metaclass=abc.ABCMeta):
|
||||
"""内容过滤器抽象类"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -16,6 +47,11 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
@property
|
||||
def enable_stages(self):
|
||||
"""启用的阶段
|
||||
|
||||
默认为消息请求AI前后的两个阶段。
|
||||
|
||||
entity.EnableStage.PRE: 消息请求AI前,此时需要检查的内容是用户的输入消息。
|
||||
entity.EnableStage.POST: 消息请求AI后,此时需要检查的内容是AI的回复消息。
|
||||
"""
|
||||
return [
|
||||
entities.EnableStage.PRE,
|
||||
@@ -28,7 +64,17 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
async def process(self, message: str=None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
对于内容过滤器来说,不需要考虑消息所处的阶段,只需要检查消息内容即可。
|
||||
|
||||
Args:
|
||||
message (str): 需要检查的内容
|
||||
image_url (str): 要检查的图片的 URL
|
||||
|
||||
Returns:
|
||||
entities.FilterResult: 过滤结果,具体内容请查看 entities.FilterResult 类的文档
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
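The filter_class registry above is what makes the chain pluggable: every decorated class lands in preregistered_filters, and ContentFilterStage instantiates the ones whose names it lists in filters_required. A minimal hedged sketch of a custom filter following that contract; the filter name, the keyword and the notice strings are illustrative, the imports assume the file sits in the filters package like the built-ins, and ResultLevel.PASS/BLOCK follow the levels referenced in the FilterResult docstrings above:

from .. import entities
from .. import filter as filter_model


@filter_model.filter_class("keyword-block")   # hypothetical filter name
class KeywordBlockFilter(filter_model.ContentFilter):
    """Blocks any message containing one illustrative keyword."""

    async def process(self, message: str = None, image_url=None) -> entities.FilterResult:
        hit = message is not None and 'forbidden-word' in message
        return entities.FilterResult(
            level=entities.ResultLevel.BLOCK if hit else entities.ResultLevel.PASS,
            replacement=message,                                   # unchanged text must still be returned
            user_notice='消息包含敏感内容' if hit else '',
            console_notice='keyword-block filter hit' if hit else '',
        )

For ContentFilterStage to actually run it, its name would also have to be added to the filters_required list built in that stage's initialize().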
@@ -10,6 +10,7 @@ BAIDU_EXAMINE_URL = "https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v
|
||||
BAIDU_EXAMINE_TOKEN_URL = "https://aip.baidubce.com/oauth/2.0/token"
|
||||
|
||||
|
||||
@filter_model.filter_class("baidu-cloud-examine")
|
||||
class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
"""百度云内容审核"""
|
||||
|
||||
|
||||
@@ -6,34 +6,30 @@ from .. import entities
|
||||
from ....config import manager as cfg_mgr
|
||||
|
||||
|
||||
@filter_model.filter_class("ban-word-filter")
|
||||
class BanWordFilter(filter_model.ContentFilter):
|
||||
"""根据内容禁言"""
|
||||
|
||||
sensitive: cfg_mgr.ConfigManager
|
||||
"""根据内容过滤"""
|
||||
|
||||
async def initialize(self):
|
||||
self.sensitive = await cfg_mgr.load_json_config(
|
||||
"data/config/sensitive-words.json",
|
||||
"templates/sensitive-words.json"
|
||||
)
|
||||
pass
|
||||
|
||||
async def process(self, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.sensitive.data['words']:
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
match = re.findall(word, message)
|
||||
|
||||
if len(match) > 0:
|
||||
found = True
|
||||
|
||||
for i in range(len(match)):
|
||||
if self.sensitive.data['mask_word'] == "":
|
||||
if self.ap.sensitive_meta.data['mask_word'] == "":
|
||||
message = message.replace(
|
||||
match[i], self.sensitive.data['mask'] * len(match[i])
|
||||
match[i], self.ap.sensitive_meta.data['mask'] * len(match[i])
|
||||
)
|
||||
else:
|
||||
message = message.replace(
|
||||
match[i], self.sensitive.data['mask_word']
|
||||
match[i], self.ap.sensitive_meta.data['mask_word']
|
||||
)
|
||||
|
||||
return entities.FilterResult(
|
||||
|
||||
@@ -5,6 +5,7 @@ from .. import entities
|
||||
from .. import filter as filter_model
|
||||
|
||||
|
||||
@filter_model.filter_class("content-ignore")
|
||||
class ContentIgnore(filter_model.ContentFilter):
|
||||
"""根据内容忽略消息"""
|
||||
|
||||
|
||||
@@ -4,8 +4,8 @@ import asyncio
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from . import app, entities
|
||||
from ..pipeline import entities as pipeline_entities
|
||||
from ..core import app, entities
|
||||
from . import entities as pipeline_entities
|
||||
from ..plugin import events
|
||||
|
||||
|
||||
@@ -68,7 +68,7 @@ class Controller:
|
||||
"""检查输出
|
||||
"""
|
||||
if result.user_notice:
|
||||
await self.ap.im_mgr.send(
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
result.user_notice,
|
||||
query.adapter
|
||||
@@ -85,7 +85,7 @@ class Controller:
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
):
|
||||
"""从指定阶段开始执行
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
如何看懂这里为什么这么写?
|
||||
去问 GPT-4:
|
||||
@@ -15,6 +15,11 @@ from ...config import manager as cfg_mgr
|
||||
|
||||
@stage.stage_class("LongTextProcessStage")
|
||||
class LongTextProcessStage(stage.PipelineStage):
|
||||
"""长消息处理阶段
|
||||
|
||||
改写:
|
||||
- resp_message_chain
|
||||
"""
|
||||
|
||||
strategy_impl: strategy.LongTextStrategy
|
||||
|
||||
@@ -29,30 +34,44 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
if os.name == "nt":
|
||||
use_font = "C:/Windows/Fonts/msyh.ttc"
|
||||
if not os.path.exists(use_font):
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用Windows自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
|
||||
config['blob_message_strategy'] = "forward"
|
||||
else:
|
||||
self.ap.logger.info("使用Windows自带字体:" + use_font)
|
||||
config['font-path'] = use_font
|
||||
else:
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。")
|
||||
self.ap.logger.warn("未找到字体文件,且无法使用系统自带字体,更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。")
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
except:
|
||||
traceback.print_exc()
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在config.py中调整相关设置。".format(use_font))
|
||||
self.ap.logger.error("加载字体文件失败({}),更换为转发消息组件以发送长消息,您可以在配置文件中调整相关设置。".format(use_font))
|
||||
|
||||
self.ap.platform_cfg.data['long-text-process']['strategy'] = "forward"
|
||||
|
||||
if config['strategy'] == 'image':
|
||||
self.strategy_impl = image.Text2ImageStrategy(self.ap)
|
||||
elif config['strategy'] == 'forward':
|
||||
self.strategy_impl = forward.ForwardComponentStrategy(self.ap)
|
||||
for strategy_cls in strategy.preregistered_strategies:
|
||||
if strategy_cls.name == config['strategy']:
|
||||
self.strategy_impl = strategy_cls(self.ap)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"未找到名为 {config['strategy']} 的长消息处理策略")
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
if len(str(query.resp_message_chain)) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||
query.resp_message_chain = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain), query))
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
for msg in query.resp_message_chain[-1]:
|
||||
if not isinstance(msg, Plain):
|
||||
contains_non_plain = True
|
||||
break
|
||||
|
||||
if contains_non_plain:
|
||||
self.ap.logger.debug("消息中包含非 Plain 组件,跳过长消息处理。")
|
||||
elif len(str(query.resp_message_chain[-1])) > self.ap.platform_cfg.data['long-text-process']['threshold']:
|
||||
query.resp_message_chain[-1] = MessageChain(await self.strategy_impl.process(str(query.resp_message_chain[-1]), query))
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
||||
@@ -36,6 +36,7 @@ class Forward(MessageComponent):
|
||||
return '[聊天记录]'
|
||||
|
||||
|
||||
@strategy_model.strategy_class("forward")
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
|
||||
@@ -15,6 +15,7 @@ from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@strategy_model.strategy_class("image")
|
||||
class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
text_render_font: ImageFont.FreeTypeFont
|
||||
|
||||
@@ -9,7 +9,39 @@ from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
|
||||
|
||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||
|
||||
|
||||
def strategy_class(
|
||||
name: str
|
||||
) -> typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]:
|
||||
"""长文本处理策略类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 策略名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[LongTextStrategy]], typing.Type[LongTextStrategy]]: 装饰器
|
||||
"""
|
||||
|
||||
def decorator(cls: typing.Type[LongTextStrategy]) -> typing.Type[LongTextStrategy]:
|
||||
assert issubclass(cls, LongTextStrategy)
|
||||
|
||||
cls.name = name
|
||||
|
||||
preregistered_strategies.append(cls)
|
||||
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
"""长文本处理策略抽象类
|
||||
"""
|
||||
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
@@ -20,4 +52,15 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
Args:
|
||||
message (str): 消息
|
||||
query (core_entities.Query): 此次请求的上下文对象
|
||||
|
||||
Returns:
|
||||
list[mirai.models.messages.MessageComponent]: 转换后的 YiriMirai 消息组件列表
|
||||
"""
|
||||
return []
|
||||
|
||||
35  pkg/pipeline/msgtrun/msgtrun.py  Normal file
@@ -0,0 +1,35 @@
from __future__ import annotations

from .. import stage, entities, stagemgr
from ...core import entities as core_entities
from . import truncator
from .truncators import round


@stage.stage_class("ConversationMessageTruncator")
class ConversationMessageTruncator(stage.PipelineStage):
    """Conversation message truncator.

    Truncates the conversation message chain to fit the platform message length limit.
    """
    trun: truncator.Truncator

    async def initialize(self):
        use_method = self.ap.pipeline_cfg.data['msg-truncate']['method']

        for trun in truncator.preregistered_truncators:
            if trun.name == use_method:
                self.trun = trun(self.ap)
                break
        else:
            raise ValueError(f"未知的截断器: {use_method}")

    async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
        """Process.
        """
        query = await self.trun.truncate(query)

        return entities.StageProcessResult(
            result_type=entities.ResultType.CONTINUE,
            new_query=query
        )
56  pkg/pipeline/msgtrun/truncator.py  Normal file
@@ -0,0 +1,56 @@
from __future__ import annotations

import typing
import abc

from ...core import entities as core_entities, app


preregistered_truncators: list[typing.Type[Truncator]] = []


def truncator_class(
    name: str
) -> typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]:
    """Class decorator for truncators.

    Args:
        name (str): truncator name

    Returns:
        typing.Callable[[typing.Type[Truncator]], typing.Type[Truncator]]: the decorator
    """
    def decorator(cls: typing.Type[Truncator]) -> typing.Type[Truncator]:
        assert issubclass(cls, Truncator)

        cls.name = name

        preregistered_truncators.append(cls)

        return cls

    return decorator


class Truncator(abc.ABC):
    """Base class for message truncators.
    """

    name: str

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        pass

    @abc.abstractmethod
    async def truncate(self, query: core_entities.Query) -> core_entities.Query:
        """Truncate.

        Normally only query.messages needs to be modified; query.prompt and query.user_message may also be adjusted.
        Do not touch any other fields.
        """
        pass
0  pkg/pipeline/msgtrun/truncators/__init__.py  Normal file
32  pkg/pipeline/msgtrun/truncators/round.py  Normal file
@@ -0,0 +1,32 @@
from __future__ import annotations

from .. import truncator
from ....core import entities as core_entities


@truncator.truncator_class("round")
class RoundTruncator(truncator.Truncator):
    """Round-based context truncator.
    """

    async def truncate(self, query: core_entities.Query) -> core_entities.Query:
        """Truncate.
        """
        max_round = self.ap.pipeline_cfg.data['msg-truncate']['round']['max-round']

        temp_messages = []

        current_round = 0

        # Traverse from newest to oldest
        for msg in query.messages[::-1]:
            if current_round < max_round:
                temp_messages.append(msg)
                if msg.role == 'user':
                    current_round += 1
            else:
                break

        query.messages = temp_messages[::-1]

        return query
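To see what the round truncator does in practice: with max-round = 2, the reverse traversal collects messages until it has walked past two 'user' messages, then the list is reversed back, so only the last two question/answer rounds survive. A standalone toy sketch of the same logic, with plain dicts standing in for the provider Message objects (not the project's class):

def keep_last_rounds(messages: list[dict], max_round: int) -> list[dict]:
    kept, rounds = [], 0
    for msg in reversed(messages):        # newest to oldest, as above
        if rounds >= max_round:
            break
        kept.append(msg)
        if msg['role'] == 'user':         # one 'user' message closes one round
            rounds += 1
    return kept[::-1]                     # back to chronological order


history = [
    {'role': 'user', 'content': 'q1'}, {'role': 'assistant', 'content': 'a1'},
    {'role': 'user', 'content': 'q2'}, {'role': 'assistant', 'content': 'a2'},
    {'role': 'user', 'content': 'q3'}, {'role': 'assistant', 'content': 'a3'},
]
assert keep_last_rounds(history, 2) == history[2:]   # only rounds q2/a2 and q3/a3 remain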
@@ -4,11 +4,12 @@ import asyncio
|
||||
|
||||
import mirai
|
||||
|
||||
from . import entities
|
||||
from ..core import entities
|
||||
from ..platform import adapter as msadapter
|
||||
|
||||
|
||||
class QueryPool:
|
||||
"""请求池,请求获得调度进入pipeline之前,保存在这里"""
|
||||
|
||||
query_id_counter: int = 0
|
||||
|
||||
@@ -42,7 +43,7 @@ class QueryPool:
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
resp_messages=[],
|
||||
resp_message_chain=None,
|
||||
resp_message_chain=[],
|
||||
adapter=adapter
|
||||
)
|
||||
self.queries.append(query)
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import mirai
|
||||
|
||||
from .. import stage, entities, stagemgr
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
@@ -8,7 +10,17 @@ from ...plugin import events
|
||||
|
||||
@stage.stage_class("PreProcessor")
|
||||
class PreProcessor(stage.PipelineStage):
|
||||
"""预处理器
|
||||
"""请求预处理阶段
|
||||
|
||||
签出会话、prompt、上文、模型、内容函数。
|
||||
|
||||
改写:
|
||||
- session
|
||||
- prompt
|
||||
- messages
|
||||
- user_message
|
||||
- use_model
|
||||
- use_funcs
|
||||
"""
|
||||
|
||||
async def process(
|
||||
@@ -27,21 +39,42 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
query.user_message = llm_entities.Message(
|
||||
role='user',
|
||||
content=str(query.message_chain).strip()
|
||||
)
|
||||
|
||||
query.use_model = conversation.use_model
|
||||
|
||||
query.use_funcs = conversation.use_funcs
|
||||
query.use_funcs = conversation.use_funcs if query.use_model.tool_call_supported else None
|
||||
|
||||
|
||||
# 检查vision是否启用,没启用就删除所有图片
|
||||
if not self.ap.provider_cfg.data['enable-vision'] or not query.use_model.vision_supported:
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
if me.type == 'image_url':
|
||||
msg.content.remove(me)
|
||||
|
||||
content_list = []
|
||||
|
||||
for me in query.message_chain:
|
||||
if isinstance(me, mirai.Plain):
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_text(me.text)
|
||||
)
|
||||
elif isinstance(me, mirai.Image):
|
||||
if self.ap.provider_cfg.data['enable-vision'] and query.use_model.vision_supported:
|
||||
if me.url is not None:
|
||||
content_list.append(
|
||||
llm_entities.ContentElement.from_image_url(str(me.url))
|
||||
)
|
||||
|
||||
query.user_message = llm_entities.Message( # TODO 适配多模态输入
|
||||
role='user',
|
||||
content=content_list
|
||||
)
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
session = query.session
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{session.launcher_type.value}_{session.launcher_id}',
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query
|
||||
@@ -51,28 +84,6 @@ class PreProcessor(stage.PipelineStage):
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
# 根据模型max_tokens剪裁
|
||||
max_tokens = min(query.use_model.max_tokens, self.ap.pipeline_cfg.data['submit-messages-tokens'])
|
||||
|
||||
test_messages = query.prompt.messages + query.messages + [query.user_message]
|
||||
|
||||
while await query.use_model.tokenizer.count_token(test_messages, query.use_model) > max_tokens:
|
||||
# 前文都pop完了,还是大于max_tokens,由于prompt和user_messages不能删减,报错
|
||||
if len(query.prompt.messages) == 0:
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query,
|
||||
user_notice='输入内容过长,请减少情景预设或者输入内容长度',
|
||||
console_notice='输入内容过长,请减少情景预设或者输入内容长度,或者增大配置文件中的 submit-messages-tokens 项(但不能超过所用模型最大tokens数)'
|
||||
)
|
||||
|
||||
query.messages.pop(0) # pop第一个肯定是role=user的
|
||||
# 继续pop到第二个role=user前一个
|
||||
while len(query.messages) > 0 and query.messages[0].role != 'user':
|
||||
query.messages.pop(0)
|
||||
|
||||
test_messages = query.prompt.messages + query.messages + [query.user_message]
|
||||
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
|
||||
@@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import typing
|
||||
import time
|
||||
import traceback
|
||||
import json
|
||||
|
||||
import mirai
|
||||
|
||||
@@ -21,8 +22,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理
|
||||
"""
|
||||
# 取session
|
||||
# 取conversation
|
||||
# 调API
|
||||
# 生成器
|
||||
|
||||
@@ -41,7 +40,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
@@ -65,17 +66,13 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
mirai.Plain(event_ctx.event.alter)
|
||||
])
|
||||
|
||||
query.messages.append(
|
||||
query.user_message
|
||||
)
|
||||
|
||||
text_length = 0
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
|
||||
async for result in query.use_model.requester.request(query):
|
||||
async for result in self.runner(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
@@ -87,6 +84,9 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')
|
||||
@@ -99,8 +99,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
debug_notice=traceback.format_exc()
|
||||
)
|
||||
finally:
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
|
||||
await self.ap.ctr_mgr.usage.post_query_record(
|
||||
session_type=query.session.launcher_type.value,
|
||||
@@ -111,3 +109,64 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
response_seconds=int(time.time() - start_time),
|
||||
retry_times=-1,
|
||||
)
|
||||
|
||||
async def runner(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""执行一个请求处理过程中的LLM接口请求、函数调用的循环
|
||||
|
||||
这是临时处理方案,后续可能改为使用LangChain或者自研的工作流处理器
|
||||
"""
|
||||
await query.use_model.requester.preprocess(query)
|
||||
|
||||
pending_tool_calls = []
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
for tool_call in pending_tool_calls:
|
||||
try:
|
||||
func = tool_call.function
|
||||
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(
|
||||
query, func.name, parameters
|
||||
)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role="tool", content=json.dumps(func_ret, ensure_ascii=False), tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
req_messages.append(msg)
|
||||
except Exception as e:
|
||||
# 工具调用出错,添加一个报错信息到 req_messages
|
||||
err_msg = llm_entities.Message(
|
||||
role="tool", content=f"err: {e}", tool_call_id=tool_call.id
|
||||
)
|
||||
|
||||
yield err_msg
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)
|
||||
|
||||
yield msg
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
|
||||
req_messages.append(msg)
|
||||
|
||||
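The runner above is the whole agent loop: call the model, yield the reply, execute every pending tool call, feed each result (or an "err: ..." string) back as a role='tool' message, and call the model again until a reply arrives with no tool_calls. A self-contained toy reproduction of that control flow with a fake model, using plain dicts instead of llm_entities.Message; the real requester and tool manager are not involved:

import asyncio
import json


async def fake_call(messages: list[dict]) -> dict:
    """Stand-in for the model call: asks for one tool, then answers."""
    if not any(m.get('role') == 'tool' for m in messages):
        return {'role': 'assistant', 'content': None,
                'tool_calls': [{'id': 't1', 'function': {'name': 'now', 'arguments': '{}'}}]}
    return {'role': 'assistant', 'content': 'done', 'tool_calls': None}


async def run_loop(messages: list[dict]):
    msg = await fake_call(messages)          # first request
    yield msg
    messages.append(msg)
    while msg.get('tool_calls'):             # keep going while tool calls are pending
        for call in msg['tool_calls']:
            tool_msg = {'role': 'tool', 'content': json.dumps({'ok': True}),
                        'tool_call_id': call['id']}
            yield tool_msg
            messages.append(tool_msg)
        msg = await fake_call(messages)      # ask again with the tool results attached
        yield msg
        messages.append(msg)


async def main():
    out = [m async for m in run_loop([{'role': 'user', 'content': 'hi'}])]
    assert [m['role'] for m in out] == ['assistant', 'tool', 'assistant']

asyncio.run(main())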
@@ -19,15 +19,16 @@ class CommandHandler(handler.MessageHandler):
|
||||
"""处理
|
||||
"""
|
||||
|
||||
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
privilege = 1
|
||||
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.system_cfg.data['admin-sessions']:
|
||||
privilege = 2
|
||||
|
||||
spt = str(query.message_chain).strip().split(' ')
|
||||
spt = command_text.split(' ')
|
||||
|
||||
event_class = events.PersonCommandSent if query.launcher_type == core_entities.LauncherTypes.PERSON else events.GroupCommandSent
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
@@ -47,12 +48,7 @@ class CommandHandler(handler.MessageHandler):
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = mirai.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
content=str(mc),
|
||||
)
|
||||
)
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
@@ -73,17 +69,12 @@ class CommandHandler(handler.MessageHandler):
|
||||
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
command_text = str(query.message_chain).strip()[1:]
|
||||
|
||||
async for ret in self.ap.cmd_mgr.execute(
|
||||
command_text=command_text,
|
||||
query=query,
|
||||
session=session
|
||||
):
|
||||
if ret.error is not None:
|
||||
# query.resp_message_chain = mirai.MessageChain([
|
||||
# mirai.Plain(str(ret.error))
|
||||
# ])
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
@@ -97,18 +88,28 @@ class CommandHandler(handler.MessageHandler):
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif ret.text is not None:
|
||||
# query.resp_message_chain = mirai.MessageChain([
|
||||
# mirai.Plain(ret.text)
|
||||
# ])
|
||||
elif ret.text is not None or ret.image_url is not None:
|
||||
|
||||
content: list[llm_entities.ContentElement]= []
|
||||
|
||||
if ret.text is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_text(ret.text)
|
||||
)
|
||||
|
||||
if ret.image_url is not None:
|
||||
content.append(
|
||||
llm_entities.ContentElement.from_image_url(ret.image_url)
|
||||
)
|
||||
|
||||
query.resp_messages.append(
|
||||
llm_entities.Message(
|
||||
role='command',
|
||||
content=ret.text,
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(ret.text)}')
|
||||
self.ap.logger.info(f'命令返回: {self.cut_str(str(content[0]))}')
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
|
||||
@@ -11,6 +11,13 @@ from ...config import manager as cfg_mgr
|
||||
|
||||
@stage.stage_class("MessageProcessor")
|
||||
class Processor(stage.PipelineStage):
|
||||
"""请求实际处理阶段
|
||||
|
||||
通过命令处理器和聊天处理器处理消息。
|
||||
|
||||
改写:
|
||||
- resp_messages
|
||||
"""
|
||||
|
||||
cmd_handler: handler.MessageHandler
|
||||
|
||||
|
||||
@@ -1,10 +1,26 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app
|
||||
|
||||
|
||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||
|
||||
def algo_class(name: str):
|
||||
|
||||
def decorator(cls: typing.Type[ReteLimitAlgo]) -> typing.Type[ReteLimitAlgo]:
|
||||
cls.name = name
|
||||
preregistered_algos.append(cls)
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
"""限流算法抽象类"""
|
||||
|
||||
name: str = None
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -16,9 +32,27 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
|
||||
"""进入处理流程
|
||||
|
||||
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
|
||||
Returns:
|
||||
bool: 是否允许进入处理流程,若返回false,则直接丢弃该请求
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def release_access(self, launcher_type: str, launcher_id: int):
|
||||
"""退出处理流程
|
||||
|
||||
Args:
|
||||
launcher_type (str): 请求者类型 群聊为 group 私聊为 person
|
||||
launcher_id (int): 请求者ID
|
||||
"""
|
||||
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -1,24 +1,22 @@
|
||||
# 固定窗口算法
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from .. import algo
|
||||
|
||||
|
||||
# 固定窗口算法
|
||||
class SessionContainer:
|
||||
|
||||
wait_lock: asyncio.Lock
|
||||
|
||||
records: dict[int, int]
|
||||
"""访问记录,key为每分钟的起始时间戳,value为访问次数"""
|
||||
"""访问记录,key为每窗口长度的起始时间戳,value为访问次数"""
|
||||
|
||||
def __init__(self):
|
||||
self.wait_lock = asyncio.Lock()
|
||||
self.records = {}
|
||||
|
||||
|
||||
@algo.algo_class("fixwin")
|
||||
class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
containers_lock: asyncio.Lock
|
||||
@@ -46,30 +44,34 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
||||
|
||||
# 等待锁
|
||||
async with container.wait_lock:
|
||||
|
||||
# 获取窗口大小和限制
|
||||
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['window-size']
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']['limit']
|
||||
|
||||
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
window_size = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['window-size']
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]['limit']
|
||||
|
||||
# 获取当前时间戳
|
||||
now = int(time.time())
|
||||
|
||||
# 获取当前分钟的起始时间戳
|
||||
now = now - now % 60
|
||||
# 获取当前窗口的起始时间戳
|
||||
now = now - now % window_size
|
||||
|
||||
# 获取当前分钟的访问次数
|
||||
# 获取当前窗口的访问次数
|
||||
count = container.records.get(now, 0)
|
||||
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin']['default']
|
||||
|
||||
if session_name in self.ap.pipeline_cfg.data['rate-limit']['fixwin']:
|
||||
limitation = self.ap.pipeline_cfg.data['rate-limit']['fixwin'][session_name]
|
||||
|
||||
# 如果访问次数超过了限制
|
||||
if count >= limitation:
|
||||
if self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'drop':
|
||||
return False
|
||||
elif self.ap.pipeline_cfg.data['rate-limit']['strategy'] == 'wait':
|
||||
# 等待下一分钟
|
||||
await asyncio.sleep(60 - time.time() % 60)
|
||||
# 等待下一窗口
|
||||
await asyncio.sleep(window_size - time.time() % window_size)
|
||||
|
||||
now = int(time.time())
|
||||
now = now - now % 60
|
||||
now = now - now % window_size
|
||||
|
||||
if now not in container.records:
|
||||
container.records = {}
|
||||
|
||||
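The arithmetic that drives the fixed-window logic above is small: every timestamp is snapped to its window start with now - now % window_size, that start is the key whose counter is compared against the limit, and the 'wait' strategy sleeps window_size - time.time() % window_size seconds to resume exactly at the next window. A self-contained sketch of the same accounting (60-second window, limit of 2), independent of the config plumbing and not the project's class:

WINDOW, LIMIT = 60, 2
records: dict[int, int] = {}


def try_acquire(ts: float) -> bool:
    window_start = int(ts) - int(ts) % WINDOW   # snap to the start of the current window
    if window_start not in records:
        records.clear()                         # previous windows are no longer relevant
    count = records.get(window_start, 0)
    if count >= LIMIT:
        return False                            # corresponds to the 'drop' strategy
    records[window_start] = count + 1
    return True


t0 = 1_700_000_000                              # any epoch second
assert [try_acquire(t0), try_acquire(t0 + 1), try_acquire(t0 + 2)] == [True, True, False]
assert try_acquire(t0 + WINDOW)                 # the next window starts with a fresh count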
@@ -11,11 +11,27 @@ from ...core import entities as core_entities
|
||||
@stage.stage_class("RequireRateLimitOccupancy")
|
||||
@stage.stage_class("ReleaseRateLimitOccupancy")
|
||||
class RateLimit(stage.PipelineStage):
|
||||
"""限速器控制阶段
|
||||
|
||||
不改写query,只检查是否需要限速。
|
||||
"""
|
||||
|
||||
algo: algo.ReteLimitAlgo
|
||||
|
||||
async def initialize(self):
|
||||
self.algo = fixedwin.FixedWindowAlgo(self.ap)
|
||||
|
||||
algo_name = self.ap.pipeline_cfg.data['rate-limit']['algo']
|
||||
|
||||
algo_class = None
|
||||
|
||||
for algo_cls in algo.preregistered_algos:
|
||||
if algo_cls.name == algo_name:
|
||||
algo_class = algo_cls
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'未知的限速算法: {algo_name}')
|
||||
|
||||
self.algo = algo_class(self.ap)
|
||||
await self.algo.initialize()
|
||||
|
||||
async def process(
|
||||
@@ -46,7 +62,7 @@ class RateLimit(stage.PipelineStage):
|
||||
)
|
||||
elif stage_inst_name == "ReleaseRateLimitOccupancy":
|
||||
await self.algo.release_access(
|
||||
query.launcher_type,
|
||||
query.launcher_type.value,
|
||||
query.launcher_id,
|
||||
)
|
||||
return entities.StageProcessResult(
|
||||
|
||||
@@ -29,9 +29,9 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
await asyncio.sleep(random_delay)
|
||||
|
||||
await self.ap.im_mgr.send(
|
||||
await self.ap.platform_mgr.send(
|
||||
query.message_event,
|
||||
query.resp_message_chain,
|
||||
query.resp_message_chain[-1],
|
||||
adapter=query.adapter
|
||||
)
|
||||
|
||||
|
||||
@@ -14,26 +14,27 @@ from ...config import manager as cfg_mgr
|
||||
@stage.stage_class("GroupRespondRuleCheckStage")
|
||||
class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
"""群组响应规则检查器
|
||||
|
||||
仅检查群消息是否符合规则。
|
||||
"""
|
||||
|
||||
rule_matchers: list[rule.GroupRespondRule]
|
||||
"""检查器实例"""
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化检查器
|
||||
"""
|
||||
self.rule_matchers = [
|
||||
atbot.AtBotRule(self.ap),
|
||||
prefix.PrefixRule(self.ap),
|
||||
regexp.RegExpRule(self.ap),
|
||||
random.RandomRespRule(self.ap),
|
||||
]
|
||||
|
||||
for rule_matcher in self.rule_matchers:
|
||||
await rule_matcher.initialize()
|
||||
self.rule_matchers = []
|
||||
|
||||
for rule_matcher in rule.preregisetered_rules:
|
||||
rule_inst = rule_matcher(self.ap)
|
||||
await rule_inst.initialize()
|
||||
self.rule_matchers.append(rule_inst)
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
|
||||
if query.launcher_type.value != 'group':
|
||||
if query.launcher_type.value != 'group': # 只处理群消息
|
||||
return entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
@@ -43,8 +44,8 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
||||
|
||||
use_rule = rules['default']
|
||||
|
||||
if str(query.launcher_id) in use_rule:
|
||||
use_rule = use_rule[str(query.launcher_id)]
|
||||
if str(query.launcher_id) in rules:
|
||||
use_rule = rules[str(query.launcher_id)]
|
||||
|
||||
for rule_matcher in self.rule_matchers: # 任意一个匹配就放行
|
||||
res = await rule_matcher.match(str(query.message_chain), query.message_chain, use_rule, query)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
import mirai
|
||||
|
||||
@@ -7,9 +8,20 @@ from ...core import app, entities as core_entities
|
||||
from . import entities
|
||||
|
||||
|
||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||
|
||||
def rule_class(name: str):
|
||||
def decorator(cls: typing.Type[GroupRespondRule]) -> typing.Type[GroupRespondRule]:
|
||||
cls.name = name
|
||||
preregisetered_rules.append(cls)
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
|
||||
class GroupRespondRule(metaclass=abc.ABCMeta):
|
||||
"""群组响应规则的抽象类
|
||||
"""
|
||||
name: str
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("at-bot")
|
||||
class AtBotRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
@@ -19,6 +20,10 @@ class AtBotRule(rule_model.GroupRespondRule):
|
||||
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
|
||||
if message_chain.has(mirai.At(query.adapter.bot_account_id)): # 回复消息时会at两次,检查并删除重复的
|
||||
message_chain.remove(mirai.At(query.adapter.bot_account_id))
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=message_chain,
|
||||
|
||||
@@ -5,6 +5,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("prefix")
|
||||
class PrefixRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
@@ -19,11 +20,14 @@ class PrefixRule(rule_model.GroupRespondRule):
|
||||
for prefix in prefixes:
|
||||
if message_text.startswith(prefix):
|
||||
|
||||
# 查找第一个plain元素
|
||||
for me in message_chain:
|
||||
if isinstance(me, mirai.Plain):
|
||||
me.text = me.text[len(prefix):]
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
matching=True,
|
||||
replacement=mirai.MessageChain([
|
||||
mirai.Plain(message_text[len(prefix):])
|
||||
]),
|
||||
replacement=message_chain,
|
||||
)
|
||||
|
||||
return entities.RuleJudgeResult(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("random")
|
||||
class RandomRespRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
|
||||
@@ -7,6 +7,7 @@ from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
|
||||
|
||||
@rule_model.rule_class("regexp")
|
||||
class RegExpRule(rule_model.GroupRespondRule):
|
||||
|
||||
async def match(
|
||||
|
||||
@@ -13,20 +13,23 @@ from .respback import respback
|
||||
from .wrapper import wrapper
|
||||
from .preproc import preproc
|
||||
from .ratelimit import ratelimit
|
||||
from .msgtrun import msgtrun
|
||||
|
||||
|
||||
# 请求处理阶段顺序
|
||||
stage_order = [
|
||||
"GroupRespondRuleCheckStage",
|
||||
"BanSessionCheckStage",
|
||||
"PreContentFilterStage",
|
||||
"PreProcessor",
|
||||
"RequireRateLimitOccupancy",
|
||||
"MessageProcessor",
|
||||
"ReleaseRateLimitOccupancy",
|
||||
"PostContentFilterStage",
|
||||
"ResponseWrapper",
|
||||
"LongTextProcessStage",
|
||||
"SendResponseBackStage",
|
||||
"GroupRespondRuleCheckStage", # 群响应规则检查
|
||||
"BanSessionCheckStage", # 封禁会话检查
|
||||
"PreContentFilterStage", # 内容过滤前置阶段
|
||||
"PreProcessor", # 预处理器
|
||||
"ConversationMessageTruncator", # 会话消息截断器
|
||||
"RequireRateLimitOccupancy", # 请求速率限制占用
|
||||
"MessageProcessor", # 处理器
|
||||
"ReleaseRateLimitOccupancy", # 释放速率限制占用
|
||||
"PostContentFilterStage", # 内容过滤后置阶段
|
||||
"ResponseWrapper", # 响应包装器
|
||||
"LongTextProcessStage", # 长文本处理
|
||||
"SendResponseBackStage", # 发送响应
|
||||
]
|
||||
|
||||
|
||||
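stage_order is what gives the pipeline its shape: the stage manager resolves each name to the instance registered via @stage.stage_class, and the controller threads the query through them in this order until a stage returns INTERRUPT. A reduced sketch of that walk, which deliberately ignores the generator-based forking the real Controller implements, so it is an approximation rather than the project's dispatcher:

async def run_pipeline(stages: dict, query, order: list[str]):
    """`stages` maps each name in `order` to an initialized PipelineStage-like object."""
    for name in order:
        result = await stages[name].process(query, name)
        if result.result_type.name == 'INTERRUPT':
            return result              # a stage vetoed the query
        query = result.new_query       # later stages see the possibly rewritten query
    return query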
|
||||
@@ -14,6 +14,13 @@ from ...plugin import events
|
||||
|
||||
@stage.stage_class("ResponseWrapper")
|
||||
class ResponseWrapper(stage.PipelineStage):
|
||||
"""回复包装阶段
|
||||
|
||||
把回复的 message 包装成人类识读的形式。
|
||||
|
||||
改写:
|
||||
- resp_message_chain
|
||||
"""
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
@@ -26,67 +33,48 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
"""处理
|
||||
"""
|
||||
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
query.resp_message_chain = mirai.MessageChain("[bot] "+query.resp_messages[-1].content)
|
||||
# 如果 resp_messages[-1] 已经是 MessageChain 了
|
||||
if isinstance(query.resp_messages[-1], mirai.MessageChain):
|
||||
query.resp_message_chain.append(query.resp_messages[-1])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
if query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
if query.resp_messages[-1].role == 'command':
|
||||
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))
|
||||
|
||||
reply_text = ''
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
elif query.resp_messages[-1].role == 'plugin':
|
||||
# if not isinstance(query.resp_messages[-1].content, mirai.MessageChain):
|
||||
# query.resp_message_chain.append(mirai.MessageChain(query.resp_messages[-1].content))
|
||||
# else:
|
||||
# query.resp_message_chain.append(query.resp_messages[-1].content)
|
||||
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain())
|
||||
|
||||
if result.content is not None: # 有内容
|
||||
reply_text = result.content
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
)
|
||||
)
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
if query.resp_messages[-1].role == 'assistant':
|
||||
result = query.resp_messages[-1]
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
reply_text = ''
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
|
||||
if self.ap.platform_cfg.data['track-function-calls']:
|
||||
if result.content: # 有内容
|
||||
reply_text = str(result.get_content_mirai_message_chain())
|
||||
|
||||
# ============= 触发插件事件 ===============
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
@@ -100,7 +88,6 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
@@ -109,13 +96,56 @@ class ResponseWrapper(stage.PipelineStage):
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain(event_ctx.event.reply)
|
||||
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain = mirai.MessageChain([mirai.Plain(reply_text)])
|
||||
query.resp_message_chain.append(result.get_content_mirai_message_chain())
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
if result.tool_calls is not None: # 有函数调用
|
||||
|
||||
function_names = [tc.function.name for tc in result.tool_calls]
|
||||
|
||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
|
||||
|
||||
if self.ap.platform_cfg.data['track-function-calls']:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.NormalMessageResponded(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
session=session,
|
||||
prefix='',
|
||||
response_text=reply_text,
|
||||
finish_reason='stop',
|
||||
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None else [],
|
||||
query=query
|
||||
)
|
||||
)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.INTERRUPT,
|
||||
new_query=query
|
||||
)
|
||||
else:
|
||||
if event_ctx.event.reply is not None:
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain(event_ctx.event.reply))
|
||||
|
||||
else:
|
||||
|
||||
query.resp_message_chain.append(mirai.MessageChain([mirai.Plain(reply_text)]))
|
||||
|
||||
yield entities.StageProcessResult(
|
||||
result_type=entities.ResultType.CONTINUE,
|
||||
new_query=query
|
||||
)
|
||||
|
||||
@@ -14,6 +14,14 @@ preregistered_adapters: list[typing.Type[MessageSourceAdapter]] = []
|
||||
def adapter_class(
|
||||
name: str
|
||||
):
|
||||
"""消息平台适配器类装饰器
|
||||
|
||||
Args:
|
||||
name (str): 适配器名称
|
||||
|
||||
Returns:
|
||||
typing.Callable[[typing.Type[MessageSourceAdapter]], typing.Type[MessageSourceAdapter]]: 装饰器
|
||||
"""
|
||||
def decorator(cls: typing.Type[MessageSourceAdapter]) -> typing.Type[MessageSourceAdapter]:
|
||||
cls.name = name
|
||||
preregistered_adapters.append(cls)
|
||||
@@ -22,15 +30,24 @@ def adapter_class(
|
||||
|
||||
|
||||
class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
"""消息平台适配器基类"""
|
||||
|
||||
name: str
|
||||
|
||||
bot_account_id: int
|
||||
"""机器人账号ID,需要在初始化时设置"""
|
||||
|
||||
config: dict
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
"""初始化适配器
|
||||
|
||||
Args:
|
||||
config (dict): 对应的配置
|
||||
ap (app.Application): 应用上下文
|
||||
"""
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
|
||||
@@ -40,7 +57,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
):
|
||||
"""发送消息
|
||||
"""主动发送消息
|
||||
|
||||
Args:
|
||||
target_type (str): 目标类型,`person`或`group`
|
||||
|
||||
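New platforms plug in through this same base class: decorate a subclass with @adapter_class and implement the messaging hooks. A skeletal, purely hypothetical registration is sketched below, assuming it lives next to the built-in sources; only send_message is shown, while reply_message, the event-listener registration and the run loop that a working adapter must also provide are omitted:

import mirai

from .. import adapter as adapter_model


@adapter_model.adapter_class("dummy")   # hypothetical adapter name
class DummyAdapter(adapter_model.MessageSourceAdapter):

    bot_account_id: int = 0

    async def send_message(self, target_type: str, target_id: str, message: mirai.MessageChain):
        # A real adapter would translate `message` into the platform's own format here.
        self.ap.logger.info(f'[dummy] send to {target_type} {target_id}: {message}')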
@@ -82,8 +82,8 @@ class PlatformManager:
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.GroupMessageReceived(
|
||||
launcher_type='person',
|
||||
launcher_id=event.sender.id,
|
||||
launcher_type='group',
|
||||
launcher_id=event.group.id,
|
||||
sender_id=event.sender.id,
|
||||
message_chain=event.message_chain,
|
||||
query=None
|
||||
@@ -146,7 +146,7 @@ class PlatformManager:
|
||||
if len(self.adapters) == 0:
|
||||
self.ap.logger.warning('未运行平台适配器,请根据文档配置并启用平台适配器。')
|
||||
|
||||
async def send(self, event, msg, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
async def send(self, event: mirai.MessageEvent, msg: mirai.MessageChain, adapter: msadapter.MessageSourceAdapter, check_quote=True, check_at_sender=True):
|
||||
|
||||
if check_at_sender and self.ap.platform_cfg.data['at-sender'] and isinstance(event, GroupMessage):
|
||||
|
||||
@@ -163,25 +163,6 @@ class PlatformManager:
|
||||
quote_origin=True if self.ap.platform_cfg.data['quote-origin'] and check_quote else False
|
||||
)
|
||||
|
||||
# 通知系统管理员
|
||||
# TODO delete
|
||||
# async def notify_admin(self, message: str):
|
||||
# await self.notify_admin_message_chain(MessageChain([Plain("[bot]{}".format(message))]))
|
||||
|
||||
# async def notify_admin_message_chain(self, message: mirai.MessageChain):
|
||||
# if self.ap.system_cfg.data['admin-sessions'] != []:
|
||||
|
||||
# admin_list = []
|
||||
# for admin in self.ap.system_cfg.data['admin-sessions']:
|
||||
# admin_list.append(admin)
|
||||
|
||||
# for adm in admin_list:
|
||||
# self.adapter.send_message(
|
||||
# adm.split("_")[0],
|
||||
# adm.split("_")[1],
|
||||
# message
|
||||
# )
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
tasks = []
|
||||
|
||||
@@ -30,7 +30,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
msg_id = msg.id
|
||||
msg_time = msg.time
|
||||
elif type(msg) is mirai.Image:
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(msg.path))
|
||||
arg = ''
|
||||
if msg.base64:
|
||||
arg = msg.base64
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(f"base64://{arg}"))
|
||||
elif msg.url:
|
||||
arg = msg.url
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(arg))
|
||||
elif msg.path:
|
||||
arg = msg.path
|
||||
msg_list.append(aiocqhttp.MessageSegment.image(arg))
|
||||
elif type(msg) is mirai.At:
|
||||
msg_list.append(aiocqhttp.MessageSegment.at(msg.target))
|
||||
elif type(msg) is mirai.AtAll:
|
||||
@@ -40,7 +49,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
||||
elif type(msg) is mirai.Voice:
|
||||
msg_list.append(aiocqhttp.MessageSegment.record(msg.path))
|
||||
elif type(msg) is forward.Forward:
|
||||
# print("aiocqhttp 暂不支持转发消息组件的转换,使用普通消息链发送")
|
||||
|
||||
for node in msg.node_list:
|
||||
msg_list.extend(AiocqhttpMessageConverter.yiri2target(node.message_chain)[0])
|
||||
@@ -170,7 +178,7 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
||||
name=event.sender["nickname"],
|
||||
permission=mirai.models.entities.Permission.Member,
|
||||
),
|
||||
special_title=event.sender["title"],
|
||||
special_title=event.sender["title"] if "title" in event.sender else "",
|
||||
join_timestamp=0,
|
||||
last_speak_timestamp=0,
|
||||
mute_time_remaining=0,
|
||||
@@ -216,13 +224,21 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
||||
|
||||
self.ap = ap
|
||||
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
if "access-token" in config:
|
||||
self.bot = aiocqhttp.CQHttp(access_token=config["access-token"])
|
||||
del self.config["access-token"]
|
||||
else:
|
||||
self.bot = aiocqhttp.CQHttp()
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: mirai.MessageChain
|
||||
):
|
||||
# TODO 实现发送消息
|
||||
return super().send_message(target_type, target_id, message)
|
||||
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
|
||||
|
||||
if target_type == "group":
|
||||
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
|
||||
elif target_type == "person":
|
||||
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -230,7 +246,6 @@ class AiocqhttpAdapter(adapter.MessageSourceAdapter):
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
|
||||
aiocq_event = AiocqhttpEventConverter.yiri2target(message_source, self.bot_account_id)
|
||||
aiocq_msg = AiocqhttpMessageConverter.yiri2target(message)[0]
|
||||
if quote_origin:
|
||||
|
||||
@@ -24,6 +24,8 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
            msg_list = message_chain.__root__
        elif type(message_chain) is list:
            msg_list = message_chain
        elif type(message_chain) is str:
            msg_list = [mirai.Plain(message_chain)]
        else:
            raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))

@@ -320,7 +322,7 @@ class NakuruProjectAdapter(adapter_model.MessageSourceAdapter):
                proxies=None
            )
            if resp.status_code == 403:
                raise Exception("go-cqhttp拒绝访问,请检查config.py中nakuru_config的token是否与go-cqhttp设置的access-token匹配")
                raise Exception("go-cqhttp拒绝访问,请检查配置文件中nakuru适配器的配置")
            self.bot_account_id = int(resp.json()['data']['user_id'])
        except Exception as e:
            raise Exception("获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确")
|
||||
|
||||
@@ -17,6 +17,7 @@ import botpy.types.message as botpy_message_type
|
||||
from .. import adapter as adapter_model
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
from ...config import manager as cfg_mgr
|
||||
|
||||
|
||||
class OfficialGroupMessage(mirai.GroupMessage):
|
||||
@@ -34,6 +35,7 @@ cached_message_ids = {}
|
||||
|
||||
id_index = 0
|
||||
|
||||
|
||||
def save_msg_id(message_id: str) -> int:
|
||||
"""保存消息id"""
|
||||
global id_index, cached_message_ids
|
||||
@@ -43,43 +45,82 @@ def save_msg_id(message_id: str) -> int:
|
||||
cached_message_ids[str(crt_index)] = message_id
|
||||
return crt_index
|
||||
|
||||
cached_member_openids = {}
|
||||
"""QQ官方 用户的id是字符串,而YiriMirai的用户id是整数,所以需要一个索引来进行转换"""
|
||||
|
||||
member_openid_index = 100
|
||||
def char_to_value(char):
|
||||
"""将单个字符转换为相应的数值。"""
|
||||
if '0' <= char <= '9':
|
||||
return ord(char) - ord('0')
|
||||
elif 'A' <= char <= 'Z':
|
||||
return ord(char) - ord('A') + 10
|
||||
|
||||
def save_member_openid(member_openid: str) -> int:
|
||||
"""保存用户id"""
|
||||
global member_openid_index, cached_member_openids
|
||||
return ord(char) - ord('a') + 36
|
||||
|
||||
if member_openid in cached_member_openids.values():
|
||||
return list(cached_member_openids.keys())[list(cached_member_openids.values()).index(member_openid)]
|
||||
def digest(s: str) -> int:
|
||||
"""计算字符串的hash值。"""
|
||||
# 取末尾的8位
|
||||
sub_s = s[-10:]
|
||||
|
||||
crt_index = member_openid_index
|
||||
member_openid_index += 1
|
||||
cached_member_openids[str(crt_index)] = member_openid
|
||||
return crt_index
|
||||
number = 0
|
||||
base = 36
|
||||
|
||||
cached_group_openids = {}
|
||||
"""QQ官方 群组的id是字符串,而YiriMirai的群组id是整数,所以需要一个索引来进行转换"""
|
||||
for i in range(len(sub_s)):
|
||||
number = number * base + char_to_value(sub_s[i])
|
||||
|
||||
group_openid_index = 1000
|
||||
return number
|
||||
|
||||
def save_group_openid(group_openid: str) -> int:
|
||||
"""保存群组id"""
|
||||
global group_openid_index, cached_group_openids
|
||||
K = typing.TypeVar("K")
|
||||
V = typing.TypeVar("V")
|
||||
|
||||
if group_openid in cached_group_openids.values():
|
||||
return list(cached_group_openids.keys())[list(cached_group_openids.values()).index(group_openid)]
|
||||
|
||||
crt_index = group_openid_index
|
||||
group_openid_index += 1
|
||||
cached_group_openids[str(crt_index)] = group_openid
|
||||
return crt_index
|
||||
class OpenIDMapping(typing.Generic[K, V]):

    map: dict[K, V]

    dump_func: typing.Callable

    digest_func: typing.Callable[[K], V]

    def __init__(self, map: dict[K, V], dump_func: typing.Callable, digest_func: typing.Callable[[K], V] = digest):
        self.map = map

        self.dump_func = dump_func

        self.digest_func = digest_func

    def __getitem__(self, key: K) -> V:
        return self.map[key]

    def __setitem__(self, key: K, value: V):
        self.map[key] = value
        self.dump_func()

    def __contains__(self, key: K) -> bool:
        return key in self.map

    def __delitem__(self, key: K):
        del self.map[key]
        self.dump_func()

    def getkey(self, value: V) -> K:
        return list(self.map.keys())[list(self.map.values()).index(value)]

    def save_openid(self, key: K) -> V:

        if key in self.map:
            return self.map[key]

        value = self.digest_func(key)

        self.map[key] = value

        self.dump_func()

        return value
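A hedged usage sketch (not part of the diff) of how this mapping replaces the old incrementing-index caches: the integer id is now a deterministic digest of the open-id, and every write is persisted through dump_func. It assumes OpenIDMapping and digest (defined above) are in scope; the no-op lambda stands in for the ConfigManager dump used by the real adapter.

members = OpenIDMapping(map={}, dump_func=lambda: None)   # str member_openid -> int
uid = members.save_openid("A1B2C3D4E5")                   # digest of the trailing characters
assert members.save_openid("A1B2C3D4E5") == uid           # same open-id, same integer id
assert members.getkey(uid) == "A1B2C3D4E5"                # reverse lookup when replying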
|
||||
|
||||
|
||||
class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
"""QQ 官方消息转换器"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(message_chain: mirai.MessageChain):
|
||||
"""将 YiriMirai 的消息链转换为 QQ 官方消息"""
|
||||
@@ -89,8 +130,12 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
msg_list = message_chain.__root__
|
||||
elif type(message_chain) is list:
|
||||
msg_list = message_chain
|
||||
elif type(message_chain) is str:
|
||||
msg_list = [mirai.Plain(text=message_chain)]
|
||||
else:
|
||||
raise Exception("Unknown message type: " + str(message_chain) + str(type(message_chain)))
|
||||
raise Exception(
|
||||
"Unknown message type: " + str(message_chain) + str(type(message_chain))
|
||||
)
|
||||
|
||||
offcial_messages: list[dict] = []
|
||||
"""
|
||||
@@ -108,36 +153,24 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
# 遍历并转换
|
||||
for component in msg_list:
|
||||
if type(component) is mirai.Plain:
|
||||
offcial_messages.append({
|
||||
"type": "text",
|
||||
"content": component.text
|
||||
})
|
||||
offcial_messages.append({"type": "text", "content": component.text})
|
||||
elif type(component) is mirai.Image:
|
||||
if component.url is not None:
|
||||
offcial_messages.append(
|
||||
{
|
||||
"type": "image",
|
||||
"content": component.url
|
||||
}
|
||||
)
|
||||
offcial_messages.append({"type": "image", "content": component.url})
|
||||
elif component.path is not None:
|
||||
offcial_messages.append(
|
||||
{
|
||||
"type": "file_image",
|
||||
"content": component.path
|
||||
}
|
||||
{"type": "file_image", "content": component.path}
|
||||
)
|
||||
elif type(component) is mirai.At:
|
||||
offcial_messages.append(
|
||||
{
|
||||
"type": "at",
|
||||
"content": ""
|
||||
}
|
||||
)
|
||||
offcial_messages.append({"type": "at", "content": ""})
|
||||
elif type(component) is mirai.AtAll:
|
||||
print("上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
|
||||
print(
|
||||
"上层组件要求发送 AtAll 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
)
|
||||
elif type(component) is mirai.Voice:
|
||||
print("上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。")
|
||||
print(
|
||||
"上层组件要求发送 Voice 消息,但 QQ 官方 API 不支持此消息类型,忽略此消息。"
|
||||
)
|
||||
elif type(component) is forward.Forward:
|
||||
# 转发消息
|
||||
yiri_forward_node_list = component.node_list
|
||||
@@ -148,20 +181,31 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
message_chain = yiri_forward_node.message_chain
|
||||
|
||||
# 平铺
|
||||
offcial_messages.extend(OfficialMessageConverter.yiri2target(message_chain))
|
||||
offcial_messages.extend(
|
||||
OfficialMessageConverter.yiri2target(message_chain)
|
||||
)
|
||||
except Exception as e:
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
return offcial_messages
|
||||
|
||||
@staticmethod
|
||||
def extract_message_chain_from_obj(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage], message_id: str = None, bot_account_id: int = 0) -> mirai.MessageChain:
|
||||
def extract_message_chain_from_obj(
|
||||
message: typing.Union[botpy_message.Message, botpy_message.DirectMessage],
|
||||
message_id: str = None,
|
||||
bot_account_id: int = 0,
|
||||
) -> mirai.MessageChain:
|
||||
yiri_msg_list = []
|
||||
|
||||
# 存id
|
||||
|
||||
yiri_msg_list.append(mirai.models.message.Source(id=save_msg_id(message_id), time=datetime.datetime.now()))
|
||||
yiri_msg_list.append(
|
||||
mirai.models.message.Source(
|
||||
id=save_msg_id(message_id), time=datetime.datetime.now()
|
||||
)
|
||||
)
|
||||
|
||||
if type(message) is not botpy_message.DirectMessage:
|
||||
yiri_msg_list.append(mirai.At(target=bot_account_id))
|
||||
@@ -177,7 +221,9 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
if attachment.content_type == "image":
|
||||
yiri_msg_list.append(mirai.Image(url=attachment.url))
|
||||
else:
|
||||
logging.warning("不支持的附件类型:" + attachment.content_type + ",忽略此附件。")
|
||||
logging.warning(
|
||||
"不支持的附件类型:" + attachment.content_type + ",忽略此附件。"
|
||||
)
|
||||
|
||||
content = re.sub(r"<@!\d+>", "", str(message.content))
|
||||
if content.strip() != "":
|
||||
@@ -190,25 +236,36 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
||||
|
||||
class OfficialEventConverter(adapter_model.EventConverter):
|
||||
"""事件转换器"""
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[mirai.Event]):
|
||||
|
||||
member_openid_mapping: OpenIDMapping[str, int]
|
||||
group_openid_mapping: OpenIDMapping[str, int]
|
||||
|
||||
def __init__(self, member_openid_mapping: OpenIDMapping[str, int], group_openid_mapping: OpenIDMapping[str, int]):
|
||||
self.member_openid_mapping = member_openid_mapping
|
||||
self.group_openid_mapping = group_openid_mapping
|
||||
|
||||
def yiri2target(self, event: typing.Type[mirai.Event]):
|
||||
if event == mirai.GroupMessage:
|
||||
return botpy_message.Message
|
||||
elif event == mirai.FriendMessage:
|
||||
return botpy_message.DirectMessage
|
||||
else:
|
||||
raise Exception("未支持转换的事件类型(YiriMirai -> Official): " + str(event))
|
||||
raise Exception(
|
||||
"未支持转换的事件类型(YiriMirai -> Official): " + str(event)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]) -> mirai.Event:
|
||||
def target2yiri(
|
||||
self,
|
||||
event: typing.Union[botpy_message.Message, botpy_message.DirectMessage]
|
||||
) -> mirai.Event:
|
||||
import mirai.models.entities as mirai_entities
|
||||
|
||||
if type(event) == botpy_message.Message: # 频道内,转群聊事件
|
||||
permission = "MEMBER"
|
||||
|
||||
if '2' in event.member.roles:
|
||||
if "2" in event.member.roles:
|
||||
permission = "ADMINISTRATOR"
|
||||
elif '4' in event.member.roles:
|
||||
elif "4" in event.member.roles:
|
||||
permission = "OWNER"
|
||||
|
||||
return mirai.GroupMessage(
|
||||
@@ -219,15 +276,25 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
group=mirai_entities.Group(
|
||||
id=event.channel_id,
|
||||
name=event.author.username,
|
||||
permission=mirai_entities.Permission.Member
|
||||
permission=mirai_entities.Permission.Member,
|
||||
),
|
||||
special_title="",
|
||||
join_timestamp=int(
|
||||
datetime.datetime.strptime(
|
||||
event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
special_title='',
|
||||
join_timestamp=int(datetime.datetime.strptime(event.member.joined_at, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.DirectMessage: # 私聊,转私聊事件
|
||||
return mirai.FriendMessage(
|
||||
@@ -236,12 +303,18 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
nickname=event.author.username,
|
||||
remark=event.author.username,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
elif type(event) == botpy_message.GroupMessage:
|
||||
|
||||
replacing_member_id = save_member_openid(event.author.member_openid)
|
||||
replacing_member_id = self.member_openid_mapping.save_openid(event.author.member_openid)
|
||||
|
||||
return OfficialGroupMessage(
|
||||
sender=mirai_entities.GroupMember(
|
||||
@@ -249,29 +322,36 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
||||
member_name=replacing_member_id,
|
||||
permission="MEMBER",
|
||||
group=mirai_entities.Group(
|
||||
id=save_group_openid(event.group_openid),
|
||||
id=self.group_openid_mapping.save_openid(event.group_openid),
|
||||
name=replacing_member_id,
|
||||
permission=mirai_entities.Permission.Member
|
||||
permission=mirai_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
special_title="",
|
||||
join_timestamp=int(0),
|
||||
last_speak_timestamp=datetime.datetime.now().timestamp(),
|
||||
mute_time_remaining=0,
|
||||
),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(event, event.id),
|
||||
time=int(datetime.datetime.strptime(event.timestamp, "%Y-%m-%dT%H:%M:%S%z").timestamp()),
|
||||
message_chain=OfficialMessageConverter.extract_message_chain_from_obj(
|
||||
event, event.id
|
||||
),
|
||||
time=int(
|
||||
datetime.datetime.strptime(
|
||||
event.timestamp, "%Y-%m-%dT%H:%M:%S%z"
|
||||
).timestamp()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@adapter_model.adapter_class("qq-botpy")
|
||||
class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
"""QQ 官方消息适配器"""
|
||||
|
||||
bot: botpy.Client = None
|
||||
|
||||
bot_account_id: int = 0
|
||||
|
||||
message_converter: OfficialMessageConverter = OfficialMessageConverter()
|
||||
# event_handler: adapter_model.EventHandler = adapter_model.EventHandler()
|
||||
message_converter: OfficialMessageConverter
|
||||
event_converter: OfficialEventConverter
|
||||
|
||||
cfg: dict = None
|
||||
|
||||
@@ -283,78 +363,110 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
|
||||
ap: app.Application
|
||||
|
||||
metadata: cfg_mgr.ConfigManager = None
|
||||
|
||||
member_openid_mapping: OpenIDMapping[str, int] = None
|
||||
group_openid_mapping: OpenIDMapping[str, int] = None
|
||||
|
||||
group_msg_seq = None
|
||||
|
||||
def __init__(self, cfg: dict, ap: app.Application):
|
||||
"""初始化适配器"""
|
||||
self.cfg = cfg
|
||||
self.ap = ap
|
||||
|
||||
self.group_msg_seq = 1
|
||||
|
||||
switchs = {}
|
||||
|
||||
for intent in cfg['intents']:
|
||||
for intent in cfg["intents"]:
|
||||
switchs[intent] = True
|
||||
|
||||
del cfg['intents']
|
||||
del cfg["intents"]
|
||||
|
||||
intents = botpy.Intents(**switchs)
|
||||
|
||||
self.bot = botpy.Client(intents=intents)
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
target_id: str,
|
||||
message: mirai.MessageChain
|
||||
self, target_type: str, target_id: str, message: mirai.MessageChain
|
||||
):
|
||||
pass
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
|
||||
for msg in message_list:
|
||||
args = {}
|
||||
|
||||
if msg["type"] == "text":
|
||||
args["content"] = msg["content"]
|
||||
elif msg["type"] == "image":
|
||||
args["image"] = msg["content"]
|
||||
elif msg["type"] == "file_image":
|
||||
args["file_image"] = msg["content"]
|
||||
else:
|
||||
continue
|
||||
|
||||
if target_type == "group":
|
||||
args["channel_id"] = str(target_id)
|
||||
|
||||
await self.bot.api.post_message(**args)
|
||||
elif target_type == "person":
|
||||
args["guild_id"] = str(target_id)
|
||||
|
||||
await self.bot.api.post_dms(**args)
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: mirai.MessageEvent,
|
||||
message: mirai.MessageChain,
|
||||
quote_origin: bool = False
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
tasks = []
|
||||
|
||||
msg_seq = 1
|
||||
message_list = self.message_converter.yiri2target(message)
|
||||
|
||||
for msg in message_list:
|
||||
args = {}
|
||||
|
||||
if msg['type'] == 'text':
|
||||
args['content'] = msg['content']
|
||||
elif msg['type'] == 'image':
|
||||
args['image'] = msg['content']
|
||||
elif msg['type'] == 'file_image':
|
||||
args['file_image'] = msg["content"]
|
||||
if msg["type"] == "text":
|
||||
args["content"] = msg["content"]
|
||||
elif msg["type"] == "image":
|
||||
args["image"] = msg["content"]
|
||||
elif msg["type"] == "file_image":
|
||||
args["file_image"] = msg["content"]
|
||||
else:
|
||||
continue
|
||||
|
||||
if quote_origin:
|
||||
args['message_reference'] = botpy_message_type.Reference(message_id=cached_message_ids[str(message_source.message_chain.message_id)])
|
||||
|
||||
if type(message_source) == mirai.GroupMessage:
|
||||
args['channel_id'] = str(message_source.sender.group.id)
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif type(message_source) == mirai.FriendMessage:
|
||||
args['guild_id'] = str(message_source.sender.id)
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
await self.bot.api.post_dms(**args)
|
||||
elif type(message_source) == OfficialGroupMessage:
|
||||
# args['guild_id'] = str(message_source.sender.group.id)
|
||||
# args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
# await self.bot.api.post_message(**args)
|
||||
if 'image' in args or 'file_image' in args:
|
||||
continue
|
||||
args['group_openid'] = cached_group_openids[str(message_source.sender.group.id)]
|
||||
args['msg_id'] = cached_message_ids[str(message_source.message_chain.message_id)]
|
||||
args['msg_seq'] = msg_seq
|
||||
msg_seq += 1
|
||||
await self.bot.api.post_group_message(
|
||||
**args
|
||||
args["message_reference"] = botpy_message_type.Reference(
|
||||
message_id=cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
)
|
||||
|
||||
if type(message_source) == mirai.GroupMessage:
|
||||
args["channel_id"] = str(message_source.sender.group.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_message(**args)
|
||||
elif type(message_source) == mirai.FriendMessage:
|
||||
args["guild_id"] = str(message_source.sender.id)
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
await self.bot.api.post_dms(**args)
|
||||
elif type(message_source) == OfficialGroupMessage:
|
||||
if "image" in args or "file_image" in args:
|
||||
continue
|
||||
args["group_openid"] = self.group_openid_mapping.getkey(
|
||||
message_source.sender.group.id
|
||||
)
|
||||
|
||||
args["msg_id"] = cached_message_ids[
|
||||
str(message_source.message_chain.message_id)
|
||||
]
|
||||
args["msg_seq"] = self.group_msg_seq
|
||||
self.group_msg_seq += 1
|
||||
await self.bot.api.post_group_message(**args)
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
@@ -362,14 +474,22 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[mirai.Event, adapter_model.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
|
||||
try:
|
||||
|
||||
async def wrapper(message: typing.Union[botpy_message.Message, botpy_message.DirectMessage, botpy_message.GroupMessage]):
|
||||
async def wrapper(
|
||||
message: typing.Union[
|
||||
botpy_message.Message,
|
||||
botpy_message.DirectMessage,
|
||||
botpy_message.GroupMessage,
|
||||
]
|
||||
):
|
||||
self.cached_official_messages[str(message.id)] = message
|
||||
await callback(OfficialEventConverter.target2yiri(message), self)
|
||||
await callback(self.event_converter.target2yiri(message), self)
|
||||
|
||||
for event_handler in event_handler_mapping[event_type]:
|
||||
setattr(self.bot, event_handler, wrapper)
|
||||
@@ -380,15 +500,33 @@ class OfficialAdapter(adapter_model.MessageSourceAdapter):
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[mirai.Event],
|
||||
callback: typing.Callable[[mirai.Event, adapter_model.MessageSourceAdapter], None]
|
||||
callback: typing.Callable[
|
||||
[mirai.Event, adapter_model.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
delattr(self.bot, event_handler_mapping[event_type])
|
||||
|
||||
async def run_async(self):
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
await self.bot.start(
|
||||
**self.cfg
|
||||
|
||||
self.metadata = self.ap.adapter_qq_botpy_meta
|
||||
|
||||
self.member_openid_mapping = OpenIDMapping(
|
||||
map=self.metadata.data["mapping"]["members"],
|
||||
dump_func=self.metadata.dump_config_sync,
|
||||
)
|
||||
|
||||
self.group_openid_mapping = OpenIDMapping(
|
||||
map=self.metadata.data["mapping"]["groups"],
|
||||
dump_func=self.metadata.dump_config_sync,
|
||||
)
|
||||
|
||||
self.message_converter = OfficialMessageConverter()
|
||||
self.event_converter = OfficialEventConverter(
|
||||
self.member_openid_mapping, self.group_openid_mapping
|
||||
)
|
||||
|
||||
self.ap.logger.info("运行 QQ 官方适配器")
|
||||
await self.bot.start(**self.cfg)
|
||||
|
||||
def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -9,10 +9,86 @@ from ..provider.tools import entities as tools_entities
|
||||
from ..core import app
|
||||
|
||||
|
||||
def register(
|
||||
name: str,
|
||||
description: str,
|
||||
version: str,
|
||||
author: str
|
||||
) -> typing.Callable[[typing.Type[BasePlugin]], typing.Type[BasePlugin]]:
|
||||
"""注册插件类
|
||||
|
||||
使用示例:
|
||||
|
||||
@register(
|
||||
name="插件名称",
|
||||
description="插件描述",
|
||||
version="插件版本",
|
||||
author="插件作者"
|
||||
)
|
||||
class MyPlugin(BasePlugin):
|
||||
pass
|
||||
"""
|
||||
pass
|
||||
|
||||
def handler(
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件监听器
|
||||
|
||||
使用示例:
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
|
||||
@handler(NormalMessageResponded)
|
||||
async def on_normal_message_responded(self, ctx: EventContext):
|
||||
pass
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
def llm_func(
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数
|
||||
|
||||
使用示例:
|
||||
|
||||
class MyPlugin(BasePlugin):
|
||||
|
||||
@llm_func("access_the_web_page")
|
||||
async def _(self, query, url: str, brief_len: int):
|
||||
\"""Call this function to search about the question before you answer any questions.
|
||||
- Do not search through google.com at any time.
|
||||
- If you need to search somthing, visit https://www.sogou.com/web?query=<something>.
|
||||
- If user ask you to open a url (start with http:// or https://), visit it directly.
|
||||
- Summary the plain content result by yourself, DO NOT directly output anything in the result you got.
|
||||
|
||||
Args:
|
||||
url(str): url to visit
|
||||
brief_len(int): max length of the plain text content, recommend 1024-4096, prefer 4096
|
||||
|
||||
Returns:
|
||||
str: plain text content of the web page or error message(starts with 'error:')
|
||||
\"""
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class BasePlugin(metaclass=abc.ABCMeta):
|
||||
"""插件基类"""
|
||||
|
||||
host: APIHost
|
||||
"""API宿主"""
|
||||
|
||||
ap: app.Application
|
||||
"""应用程序对象"""
|
||||
|
||||
def __init__(self, host: APIHost):
|
||||
self.host = host
|
||||
|
||||
async def initialize(self):
|
||||
"""初始化插件"""
|
||||
pass
|
||||
|
||||
|
||||
class APIHost:
|
||||
@@ -61,8 +137,10 @@ class EventContext:
|
||||
"""事件编号"""
|
||||
|
||||
host: APIHost = None
|
||||
"""API宿主"""
|
||||
|
||||
event: events.BaseEventModel = None
|
||||
"""此次事件的对象,具体类型为handler注册时指定监听的类型,可查看events.py中的定义"""
|
||||
|
||||
__prevent_default__ = False
|
||||
"""是否阻止默认行为"""
|
||||
|
||||
@@ -10,8 +10,10 @@ from ..provider import entities as llm_entities
|
||||
|
||||
|
||||
class BaseEventModel(pydantic.BaseModel):
|
||||
"""事件模型基类"""
|
||||
|
||||
query: typing.Union[core_entities.Query, None]
|
||||
"""此次请求的query对象,非请求过程的事件时为None"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
# 此模块已过时
|
||||
# 请从 pkg.plugin.context 引入 BasePlugin, EventContext 和 APIHost
|
||||
# 最早将于 v3.4 移除此模块
|
||||
|
||||
from . events import *
|
||||
from . context import EventContext, APIHost as PluginHost
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from ..core import app
|
||||
|
||||
|
||||
class PluginInstaller(metaclass=abc.ABCMeta):
|
||||
"""插件安装器抽象类"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -12,6 +12,8 @@ from ...utils import pkgmgr
|
||||
|
||||
|
||||
class GitHubRepoInstaller(installer.PluginInstaller):
|
||||
"""GitHub仓库插件安装器
|
||||
"""
|
||||
|
||||
def get_github_plugin_repo_label(self, repo_url: str) -> list[str]:
|
||||
"""获取username, repo"""
|
||||
|
||||
@@ -9,7 +9,7 @@ from . import context, events
|
||||
|
||||
|
||||
class PluginLoader(metaclass=abc.ABCMeta):
|
||||
"""插件加载器"""
|
||||
"""插件加载器抽象类"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
|
||||
@@ -5,11 +5,10 @@ import pkgutil
|
||||
import importlib
|
||||
import traceback
|
||||
|
||||
from CallingGPT.entities.namespace import get_func_schema
|
||||
|
||||
from .. import loader, events, context, models, host
|
||||
from ...core import entities as core_entities
|
||||
from ...provider.tools import entities as tools_entities
|
||||
from ...utils import funcschema
|
||||
|
||||
|
||||
class PluginLoader(loader.PluginLoader):
|
||||
@@ -29,6 +28,10 @@ class PluginLoader(loader.PluginLoader):
|
||||
setattr(models, 'on', self.on)
|
||||
setattr(models, 'func', self.func)
|
||||
|
||||
setattr(context, 'register', self.register)
|
||||
setattr(context, 'handler', self.handler)
|
||||
setattr(context, 'llm_func', self.llm_func)
|
||||
|
||||
def register(
|
||||
self,
|
||||
name: str,
|
||||
@@ -57,6 +60,8 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def on(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
@@ -83,6 +88,8 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
return wrapper
|
||||
|
||||
# 过时
|
||||
# 最早将于 v3.4 版本移除
|
||||
def func(
|
||||
self,
|
||||
name: str=None,
|
||||
@@ -91,10 +98,11 @@ class PluginLoader(loader.PluginLoader):
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = get_func_schema(func)
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
async def handler(
|
||||
plugin: context.BasePlugin,
|
||||
query: core_entities.Query,
|
||||
*args,
|
||||
**kwargs
|
||||
@@ -116,6 +124,46 @@ class PluginLoader(loader.PluginLoader):
|
||||
|
||||
return wrapper
|
||||
|
||||
def handler(
|
||||
self,
|
||||
event: typing.Type[events.BaseEventModel]
|
||||
) -> typing.Callable[[typing.Callable], typing.Callable]:
|
||||
"""注册事件处理器"""
|
||||
self.ap.logger.debug(f'注册事件处理器 {event.__name__}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
self._current_container.event_handlers[event] = func
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
def llm_func(
|
||||
self,
|
||||
name: str=None,
|
||||
) -> typing.Callable:
|
||||
"""注册内容函数"""
|
||||
self.ap.logger.debug(f'注册内容函数 {name}')
|
||||
def wrapper(func: typing.Callable) -> typing.Callable:
|
||||
|
||||
function_schema = funcschema.get_func_schema(func)
|
||||
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
|
||||
|
||||
llm_function = tools_entities.LLMFunction(
|
||||
name=function_name,
|
||||
human_desc='',
|
||||
description=function_schema['description'],
|
||||
enable=True,
|
||||
parameters=function_schema['parameters'],
|
||||
func=func,
|
||||
)
|
||||
|
||||
self._current_container.content_functions.append(llm_function)
|
||||
|
||||
return func
|
||||
|
||||
return wrapper
|
||||
|
||||
async def _walk_plugin_path(
|
||||
self,
|
||||
module,
|
||||
@@ -5,11 +5,12 @@ import traceback
|
||||
|
||||
from ..core import app
|
||||
from . import context, loader, events, installer, setting, models
|
||||
from .loaders import legacy
|
||||
from .loaders import classic
|
||||
from .installers import github
|
||||
|
||||
|
||||
class PluginManager:
|
||||
"""插件管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
@@ -25,7 +26,7 @@ class PluginManager:
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
self.loader = legacy.PluginLoader(ap)
|
||||
self.loader = classic.PluginLoader(ap)
|
||||
self.installer = github.GitHubRepoInstaller(ap)
|
||||
self.setting = setting.SettingManager(ap)
|
||||
self.api_host = context.APIHost(ap)
|
||||
@@ -51,6 +52,9 @@ class PluginManager:
|
||||
for plugin in self.plugins:
|
||||
try:
|
||||
plugin.plugin_inst = plugin.plugin_class(self.api_host)
|
||||
plugin.plugin_inst.ap = self.ap
|
||||
plugin.plugin_inst.host = self.api_host
|
||||
await plugin.plugin_inst.initialize()
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 初始化失败: {e}')
|
||||
self.ap.logger.exception(e)
|
||||
@@ -136,8 +140,7 @@ class PluginManager:
|
||||
for plugin in self.plugins:
|
||||
if plugin.enabled:
|
||||
if event.__class__ in plugin.event_handlers:
|
||||
|
||||
emitted_plugins.append(plugin)
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__}')
|
||||
|
||||
is_prevented_default_before_call = ctx.is_prevented_default()
|
||||
|
||||
@@ -150,6 +153,8 @@ class PluginManager:
|
||||
self.ap.logger.error(f'插件 {plugin.plugin_name} 触发事件 {event.__class__.__name__} 时发生错误: {e}')
|
||||
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
|
||||
|
||||
emitted_plugins.append(plugin)
|
||||
|
||||
if not is_prevented_default_before_call and ctx.is_prevented_default():
|
||||
self.ap.logger.debug(f'插件 {plugin.plugin_name} 阻止了默认行为执行')
|
||||
|
||||
|
||||
@@ -1,3 +1,7 @@
# 此模块已过时,请引入 pkg.plugin.context 中的 register, handler 和 llm_func 来注册插件、事件处理函数和内容函数
# 各个事件模型请从 pkg.plugin.events 引入
# 最早将于 v3.4 移除此模块

from __future__ import annotations

import typing

@@ -6,6 +6,7 @@ from . import context


class SettingManager:
    """插件设置管理器"""

    ap: app.Application

@@ -15,10 +16,7 @@ class SettingManager:
        self.ap = ap

    async def initialize(self):
        self.settings = await cfg_mgr.load_json_config(
            'plugins/plugins.json',
            'templates/plugin-settings.json'
        )
        self.settings = self.ap.plugin_setting_meta

    async def sync_setting(
        self,
|
||||
|
||||
@@ -4,6 +4,8 @@ import typing
import enum
import pydantic

import mirai


class FunctionCall(pydantic.BaseModel):
    name: str
@@ -19,25 +21,99 @@ class ToolCall(pydantic.BaseModel):
    function: FunctionCall


class ImageURLContentObject(pydantic.BaseModel):
    url: str

    def __str__(self):
        return self.url[:128] + ('...' if len(self.url) > 128 else '')


class ContentElement(pydantic.BaseModel):

    type: str
    """内容类型"""

    text: typing.Optional[str] = None

    image_url: typing.Optional[ImageURLContentObject] = None

    def __str__(self):
        if self.type == 'text':
            return self.text
        elif self.type == 'image_url':
            return f'[图片]({self.image_url})'
        else:
            return '未知内容'

    @classmethod
    def from_text(cls, text: str):
        return cls(type='text', text=text)

    @classmethod
    def from_image_url(cls, image_url: str):
        return cls(type='image_url', image_url=ImageURLContentObject(url=image_url))


class Message(pydantic.BaseModel):
    role: str  # user, system, assistant, tool, command
    """消息"""

    role: str  # user, system, assistant, tool, command, plugin
    """消息的角色"""

    name: typing.Optional[str] = None
    """名称,仅函数调用返回时设置"""

    content: typing.Optional[str] = None

    function_call: typing.Optional[FunctionCall] = None
    content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
    """内容"""

    tool_calls: typing.Optional[list[ToolCall]] = None
    """工具调用"""

    tool_call_id: typing.Optional[str] = None

    def readable_str(self) -> str:
        if self.content is not None:
            return self.content
        elif self.function_call is not None:
            return f'{self.function_call.name}({self.function_call.arguments})'
            return str(self.role) + ": " + str(self.get_content_mirai_message_chain())
        elif self.tool_calls is not None:
            return f'调用工具: {self.tool_calls[0].id}'
        else:
            return '未知消息'

    def get_content_mirai_message_chain(self, prefix_text: str="") -> mirai.MessageChain | None:
        """将内容转换为 Mirai MessageChain 对象

        Args:
            prefix_text (str): 首个文字组件的前缀文本
        """

        if self.content is None:
            return None
        elif isinstance(self.content, str):
            return mirai.MessageChain([mirai.Plain(prefix_text+self.content)])
        elif isinstance(self.content, list):
            mc = []
            for ce in self.content:
                if ce.type == 'text':
                    mc.append(mirai.Plain(ce.text))
                elif ce.type == 'image':
                    if ce.image_url.url.startswith("http"):
                        mc.append(mirai.Image(url=ce.image_url.url))
                    else:  # base64

                        b64_str = ce.image_url.url

                        if b64_str.startswith("data:"):
                            b64_str = b64_str.split(",")[1]

                        mc.append(mirai.Image(base64=b64_str))

            # 找第一个文字组件
            if prefix_text:
                for i, c in enumerate(mc):
                    if isinstance(c, mirai.Plain):
                        mc[i] = mirai.Plain(prefix_text+c.text)
                        break
                else:
                    mc.insert(0, mirai.Plain(prefix_text))

            return mirai.MessageChain(mc)
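A hedged usage sketch (not part of the diff) of the new multimodal content model, assuming Message and ContentElement from the module above are importable; the image URL is a placeholder.

msg = Message(
    role="user",
    content=[
        ContentElement.from_text("look at this:"),
        ContentElement.from_image_url("https://example.com/cat.jpg"),
    ],
)
print(str(msg.content[1]))   # "[图片](https://example.com/cat.jpg)"
print(msg.readable_str())    # "user: " followed by the rendered mirai chain
chain = msg.get_content_mirai_message_chain(prefix_text="[bot] ")
# the first Plain becomes "[bot] look at this:"; note that the branch above matches
# ce.type == 'image', so elements created with from_image_url (type 'image_url') are
# skipped by this method as currently written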
|
||||
|
||||
pkg/provider/modelmgr/__init__.py (new file, 0 lines)
pkg/provider/modelmgr/api.py (new file, 65 lines)
@@ -0,0 +1,65 @@
from __future__ import annotations

import abc
import typing

from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from . import entities as modelmgr_entities
from ..tools import entities as tools_entities


preregistered_requesters: list[typing.Type[LLMAPIRequester]] = []

def requester_class(name: str):

    def decorator(cls: typing.Type[LLMAPIRequester]) -> typing.Type[LLMAPIRequester]:
        cls.name = name
        preregistered_requesters.append(cls)
        return cls

    return decorator


class LLMAPIRequester(metaclass=abc.ABCMeta):
    """LLM API请求器
    """
    name: str = None

    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        pass

    async def preprocess(
        self,
        query: core_entities.Query,
    ):
        """预处理

        在这里处理特定API对Query对象的兼容性问题。
        """
        pass

    @abc.abstractmethod
    async def call(
        self,
        model: modelmgr_entities.LLMModelInfo,
        messages: typing.List[llm_entities.Message],
        funcs: typing.List[tools_entities.LLMFunction] = None,
    ) -> llm_entities.Message:
        """调用API

        Args:
            model (modelmgr_entities.LLMModelInfo): 使用的模型信息
            messages (typing.List[llm_entities.Message]): 消息对象列表
            funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.

        Returns:
            llm_entities.Message: 返回消息对象
        """
        pass
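A hedged sketch (not part of the diff) of how the requester_class registry above is meant to be extended; the name "echo-requester" and the EchoRequester class are hypothetical.

@requester_class("echo-requester")
class EchoRequester(LLMAPIRequester):
    async def call(self, model, messages, funcs=None) -> llm_entities.Message:
        # trivially echo the last plain-text message back; a real requester would call an HTTP API here
        last = messages[-1] if messages else None
        text = last.content if (last and isinstance(last.content, str)) else ""
        return llm_entities.Message(role="assistant", content=text)

# once the module defining EchoRequester is imported, the class shows up in the registry:
assert any(cls.name == "echo-requester" for cls in preregistered_requesters)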
|
||||
pkg/provider/modelmgr/apis/__init__.py (new file, 0 lines)
pkg/provider/modelmgr/apis/anthropicmsgs.py (new file, 112 lines)
@@ -0,0 +1,112 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
import anthropic
|
||||
|
||||
from .. import api, entities, errors
|
||||
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....utils import image
|
||||
|
||||
|
||||
@api.requester_class("anthropic-messages")
|
||||
class AnthropicMessages(api.LLMAPIRequester):
|
||||
"""Anthropic Messages API 请求器"""
|
||||
|
||||
client: anthropic.AsyncAnthropic
|
||||
|
||||
async def initialize(self):
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key="",
|
||||
base_url=self.ap.provider_cfg.data['requester']['anthropic-messages']['base-url'],
|
||||
timeout=self.ap.provider_cfg.data['requester']['anthropic-messages']['timeout'],
|
||||
proxies=self.ap.proxy_mgr.get_forward_proxies()
|
||||
)
|
||||
|
||||
async def call(
|
||||
self,
|
||||
model: entities.LLMModelInfo,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
args = self.ap.provider_cfg.data['requester']['anthropic-messages']['args'].copy()
|
||||
args["model"] = model.name if model.model_name is None else model.model_name
|
||||
|
||||
# 处理消息
|
||||
|
||||
# system
|
||||
system_role_message = None
|
||||
|
||||
for i, m in enumerate(messages):
|
||||
if m.role == "system":
|
||||
system_role_message = m
|
||||
|
||||
messages.pop(i)
|
||||
break
|
||||
|
||||
if isinstance(system_role_message, llm_entities.Message) \
|
||||
and isinstance(system_role_message.content, str):
|
||||
args['system'] = system_role_message.content
|
||||
|
||||
req_messages = []
|
||||
|
||||
for m in messages:
|
||||
if isinstance(m.content, str) and m.content.strip() != "":
|
||||
req_messages.append(m.dict(exclude_none=True))
|
||||
elif isinstance(m.content, list):
|
||||
# m.content = [
|
||||
# c for c in m.content if c.type == "text"
|
||||
# ]
|
||||
|
||||
# if len(m.content) > 0:
|
||||
# req_messages.append(m.dict(exclude_none=True))
|
||||
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
|
||||
for i, ce in enumerate(m.content):
|
||||
if ce.type == "image_url":
|
||||
alter_image_ele = {
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/jpeg",
|
||||
"data": await image.qq_image_url_to_base64(ce.image_url.url)
|
||||
}
|
||||
}
|
||||
msg_dict["content"][i] = alter_image_ele
|
||||
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
args["messages"] = req_messages
|
||||
|
||||
# anthropic的tools处在beta阶段,sdk不稳定,故暂时不支持
|
||||
#
|
||||
# if funcs:
|
||||
# tools = await self.ap.tool_mgr.generate_tools_for_openai(funcs)
|
||||
|
||||
# if tools:
|
||||
# args["tools"] = tools
|
||||
|
||||
try:
|
||||
resp = await self.client.messages.create(**args)
|
||||
|
||||
return llm_entities.Message(
|
||||
content=resp.content[0].text,
|
||||
role=resp.role
|
||||
)
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'api-key 无效: {e.message}')
|
||||
except anthropic.BadRequestError as e:
|
||||
raise errors.RequesterError(str(e.message))
|
||||
except anthropic.NotFoundError as e:
|
||||
if 'model: ' in str(e):
|
||||
raise errors.RequesterError(f'模型无效: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求地址无效: {e.message}')
|
||||
pkg/provider/modelmgr/apis/chatcmpl.py (new file, 142 lines)
@@ -0,0 +1,142 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import typing
|
||||
import json
|
||||
import base64
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import openai
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
import httpx
|
||||
import aiohttp
|
||||
import async_lru
|
||||
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
from ....utils import image
|
||||
|
||||
|
||||
@api.requester_class("openai-chat-completions")
|
||||
class OpenAIChatCompletions(api.LLMAPIRequester):
|
||||
"""OpenAI ChatCompletion API 请求器"""
|
||||
|
||||
client: openai.AsyncClient
|
||||
|
||||
requester_cfg: dict
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.ap = ap
|
||||
|
||||
self.requester_cfg = self.ap.provider_cfg.data['requester']['openai-chat-completions']
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key="",
|
||||
base_url=self.requester_cfg['base-url'],
|
||||
timeout=self.requester_cfg['timeout'],
|
||||
http_client=httpx.AsyncClient(
|
||||
proxies=self.ap.proxy_mgr.get_forward_proxies()
|
||||
)
|
||||
)
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
args: dict,
|
||||
) -> chat_completion.ChatCompletion:
|
||||
return await self.client.chat.completions.create(**args)
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.dict()
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
return message
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg["content"], list):
|
||||
for me in msg["content"]:
|
||||
if me["type"] == "image_url":
|
||||
me["image_url"]['url'] = await self.get_base64_str(me["image_url"]['url'])
|
||||
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
|
||||
async def call(
|
||||
self,
|
||||
model: entities.LLMModelInfo,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
content = msg_dict.get("content")
|
||||
if isinstance(content, list):
|
||||
# 检查 content 列表中是否每个部分都是文本
|
||||
if all(isinstance(part, dict) and part.get("type") == "text" for part in content):
|
||||
# 将所有文本部分合并为一个字符串
|
||||
msg_dict["content"] = "\n".join(part["text"] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
return await self._closure(req_messages, model, funcs)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
if 'context_length_exceeded' in e.message:
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
except openai.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
||||
except openai.NotFoundError as e:
|
||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
||||
except openai.RateLimitError as e:
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
@async_lru.alru_cache(maxsize=128)
|
||||
async def get_base64_str(
|
||||
self,
|
||||
original_url: str,
|
||||
) -> str:
|
||||
|
||||
base64_image = await image.qq_image_url_to_base64(original_url)
|
||||
|
||||
return f"data:image/jpeg;base64,{base64_image}"
|
||||
pkg/provider/modelmgr/apis/deepseekchatcmpl.py (new file, 53 lines)
@@ -0,0 +1,53 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import app
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
@api.requester_class("deepseek-chat-completions")
|
||||
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Deepseek ChatCompletion API 请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.requester_cfg = ap.provider_cfg.data['requester']['deepseek-chat-completions']
|
||||
self.ap = ap
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if 'content' in m and isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
pkg/provider/modelmgr/apis/moonshotchatcmpl.py (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from ....core import app
|
||||
|
||||
from . import chatcmpl
|
||||
from .. import api, entities, errors
|
||||
from ....core import entities as core_entities, app
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
|
||||
|
||||
@api.requester_class("moonshot-chat-completions")
|
||||
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
"""Moonshot ChatCompletion API 请求器"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
self.requester_cfg = ap.provider_cfg.data['requester']['moonshot-chat-completions']
|
||||
self.ap = ap
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
req_messages: list[dict],
|
||||
use_model: entities.LLMModelInfo,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = self.requester_cfg['args'].copy()
|
||||
args["model"] = use_model.name if use_model.model_name is None else use_model.model_name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args["tools"] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages
|
||||
|
||||
# deepseek 不支持多模态,把content都转换成纯文字
|
||||
for m in messages:
|
||||
if 'content' in m and isinstance(m["content"], list):
|
||||
m["content"] = " ".join([c["text"] for c in m["content"]])
|
||||
|
||||
# 删除空的
|
||||
messages = [m for m in messages if m["content"].strip() != ""]
|
||||
|
||||
args["messages"] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
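A self-contained sketch (not part of the diff) of the multimodal flattening used by the Deepseek and Moonshot requesters above: list-form content is collapsed to plain text, then blank messages are dropped (the Moonshot variant). The sample data is made up for illustration.

msgs = [
    {"role": "user", "content": [{"type": "text", "text": "hi"}, {"type": "text", "text": "there"}]},
    {"role": "assistant", "content": "   "},
]
for m in msgs:
    if 'content' in m and isinstance(m["content"], list):
        m["content"] = " ".join([c["text"] for c in m["content"]])
msgs = [m for m in msgs if m["content"].strip() != ""]
print(msgs)   # [{'role': 'user', 'content': 'hi there'}]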
|
||||
@@ -5,7 +5,7 @@ import typing
import pydantic

from . import api
from . import token, tokenizer
from . import token


class LLMModelInfo(pydantic.BaseModel):
@@ -19,11 +19,9 @@ class LLMModelInfo(pydantic.BaseModel):

    requester: api.LLMAPIRequester

    tokenizer: 'tokenizer.LLMTokenizer'

    tool_call_supported: typing.Optional[bool] = False

    max_tokens: typing.Optional[int] = 2048
    vision_supported: typing.Optional[bool] = False

    class Config:
        arbitrary_types_allowed = True
|
||||
Some files were not shown because too many files have changed in this diff.