mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
Compare commits
131 Commits
v4.0.3
...
v4.3.0.bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4012310d99 | ||
|
|
9e9bc88473 | ||
|
|
53ade384eb | ||
|
|
8b2480ad3b | ||
|
|
b176959836 | ||
|
|
a0c42a5f6e | ||
|
|
17d997c88e | ||
|
|
0ea7609ff1 | ||
|
|
28d4b1dd61 | ||
|
|
5179b3e53a | ||
|
|
288b294148 | ||
|
|
b464d238c5 | ||
|
|
e1a78e8ff9 | ||
|
|
2b8eb5f01c | ||
|
|
bf2bc70794 | ||
|
|
ebe0b68e8f | ||
|
|
39c50d3c12 | ||
|
|
621f1301b3 | ||
|
|
0b60ef0d06 | ||
|
|
41650b585a | ||
|
|
f5b893cfe0 | ||
|
|
e0abd19636 | ||
|
|
4380041c7f | ||
|
|
65814a4644 | ||
|
|
7237294008 | ||
|
|
214bc8ada9 | ||
|
|
6a1de889b4 | ||
|
|
4a319b2b20 | ||
|
|
9f269d1614 | ||
|
|
4b57771eb1 | ||
|
|
5922be7e15 | ||
|
|
10a44c70b6 | ||
|
|
5b044a1917 | ||
|
|
a60aa6f644 | ||
|
|
1a10b40b17 | ||
|
|
e2124054bf | ||
|
|
ee3da8aa17 | ||
|
|
c246470b37 | ||
|
|
f474e42b79 | ||
|
|
5553a86ac8 | ||
|
|
01613b2f0d | ||
|
|
a177786063 | ||
|
|
62b2884011 | ||
|
|
6b782f8761 | ||
|
|
0c2560cafb | ||
|
|
c5eeab2fd0 | ||
|
|
6f2fd72af6 | ||
|
|
2d06f1cadb | ||
|
|
af493c117c | ||
|
|
896fef8cce | ||
|
|
89c1972abe | ||
|
|
1627d04958 | ||
|
|
c959c99e45 | ||
|
|
0203faa8c1 | ||
|
|
35f76cb7ae | ||
|
|
c34232a26c | ||
|
|
b43dd95dc6 | ||
|
|
5331ba83d7 | ||
|
|
a2038b86f1 | ||
|
|
eb066f3485 | ||
|
|
bf98b82cf2 | ||
|
|
edd70b943d | ||
|
|
3cbc823085 | ||
|
|
48becf2c51 | ||
|
|
56c686cd5a | ||
|
|
208273c0dd | ||
|
|
2ff7ca3025 | ||
|
|
61a2361730 | ||
|
|
f80f997a89 | ||
|
|
18529a42c1 | ||
|
|
3e707b4b6e | ||
|
|
62f0a938a8 | ||
|
|
ad3a163d82 | ||
|
|
f5a4503610 | ||
|
|
ec012cf5ed | ||
|
|
d70eceb72c | ||
|
|
f271608114 | ||
|
|
793f0a9c10 | ||
|
|
4f2ec195fc | ||
|
|
e6bc009414 | ||
|
|
20dc8fb5ab | ||
|
|
9a71edfeb0 | ||
|
|
fe3fd664af | ||
|
|
6402755ac6 | ||
|
|
ac8fe049de | ||
|
|
955b391253 | ||
|
|
08c6672841 | ||
|
|
8917050fae | ||
|
|
21daef46f7 | ||
|
|
8ad60b5b64 | ||
|
|
7e17c96c30 | ||
|
|
f17b06767e | ||
|
|
70a29fc623 | ||
|
|
239223be3f | ||
|
|
b112cb320c | ||
|
|
5aaf2ba3ef | ||
|
|
f1e9f46af1 | ||
|
|
8dfef1d118 | ||
|
|
919a621bf8 | ||
|
|
3ac96f464d | ||
|
|
f9f03b81d1 | ||
|
|
42171a9c07 | ||
|
|
f1f00115c9 | ||
|
|
59bff61409 | ||
|
|
778693a804 | ||
|
|
e5b2da225c | ||
|
|
4a988b89a2 | ||
|
|
e5e8807312 | ||
|
|
1376530c2e | ||
|
|
7d34a2154b | ||
|
|
ff335130ae | ||
|
|
0afef0ac0f | ||
|
|
6447f270ea | ||
|
|
81be62e1a4 | ||
|
|
409909ccb1 | ||
|
|
b821b69dbb | ||
|
|
7e2448655e | ||
|
|
a7d2a68639 | ||
|
|
aba51409a7 | ||
|
|
5e5d37cbf1 | ||
|
|
e5a99a0fe4 | ||
|
|
a594cc07f6 | ||
|
|
0a9714fbe7 | ||
|
|
1992934dce | ||
|
|
bb930aec14 | ||
|
|
1d7f2ab701 | ||
|
|
347da6142e | ||
|
|
a9f4dc517a | ||
|
|
9d45f3f3a7 | ||
|
|
256d24718b | ||
|
|
1272b8ef16 |
10
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
10
.github/ISSUE_TEMPLATE/bug-report.yml
vendored
@@ -1,5 +1,5 @@
|
||||
name: 漏洞反馈
|
||||
description: 报错或漏洞请使用这个模板创建,不使用此模板创建的异常、漏洞相关issue将被直接关闭。由于自己操作不当/不甚了解所用技术栈引起的网络连接问题恕无法解决,请勿提 issue。容器间网络连接问题,参考文档 https://docs.langbot.app/deploy/network-details.html
|
||||
description: 【供中文用户】报错或漏洞请使用这个模板创建,不使用此模板创建的异常、漏洞相关issue将被直接关闭。由于自己操作不当/不甚了解所用技术栈引起的网络连接问题恕无法解决,请勿提 issue。容器间网络连接问题,参考文档 https://docs.langbot.app/zh/workshop/network-details.html
|
||||
title: "[Bug]: "
|
||||
labels: ["bug?"]
|
||||
body:
|
||||
@@ -7,7 +7,7 @@ body:
|
||||
attributes:
|
||||
label: 运行环境
|
||||
description: LangBot 版本、操作系统、系统架构、**Python版本**、**主机地理位置**
|
||||
placeholder: 例如:v3.3.0、CentOS x64 Python 3.10.3、Docker 的系统直接写 Docker 就行
|
||||
placeholder: 例如:v3.3.0、CentOS x64 Python 3.10.3、Docker
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
@@ -19,12 +19,12 @@ body:
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 复现步骤
|
||||
description: 如何重现这个问题,越详细越好;请贴上所有相关的配置文件和元数据文件(注意隐去敏感信息)
|
||||
description: 提供越多信息,我们会越快解决问题,建议多提供配置截图;**如果你不认真填写(只一两句话概括),我们会很生气并且立即关闭 issue 或两年后才回复你**
|
||||
validations:
|
||||
required: true
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: 启用的插件
|
||||
description: 有些情况可能和插件功能有关,建议提供插件启用情况。可以使用`!plugin`命令查看已启用的插件
|
||||
description: 有些情况可能和插件功能有关,建议提供插件启用情况。
|
||||
validations:
|
||||
required: false
|
||||
|
||||
30
.github/ISSUE_TEMPLATE/bug-report_en.yml
vendored
Normal file
30
.github/ISSUE_TEMPLATE/bug-report_en.yml
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
name: Bug report
|
||||
description: Report bugs or vulnerabilities using this template. For container network connection issues, refer to the documentation https://docs.langbot.app/en/workshop/network-details.html
|
||||
title: "[Bug]: "
|
||||
labels: ["bug?"]
|
||||
body:
|
||||
- type: input
|
||||
attributes:
|
||||
label: Runtime environment
|
||||
description: LangBot version, operating system, system architecture, **Python version**, **host location**
|
||||
placeholder: "For example: v3.3.0, CentOS x64 Python 3.10.3, Docker"
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Exception
|
||||
description: Describe the exception in detail, what happened and when it happened. **Please include log information.**
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Reproduction steps
|
||||
description: How to reproduce this problem, the more detailed the better; the more information you provide, the faster we will solve the problem. 【注意】请务必认真填写此部分,若不提供完整信息(如只有一两句话的概括),我们将不会回复!
|
||||
validations:
|
||||
required: false
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Enabled plugins
|
||||
description: Some cases may be related to plugin functionality, so please provide the plugin enablement status.
|
||||
validations:
|
||||
required: false
|
||||
4
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
4
.github/ISSUE_TEMPLATE/feature-request.yml
vendored
@@ -1,7 +1,7 @@
|
||||
name: 需求建议
|
||||
title: "[Feature]: "
|
||||
labels: ["改进"]
|
||||
description: "新功能或现有功能优化请使用这个模板;不符合类别的issue将被直接关闭"
|
||||
labels: []
|
||||
description: "【供中文用户】新功能或现有功能优化请使用这个模板;不符合类别的issue将被直接关闭"
|
||||
body:
|
||||
- type: dropdown
|
||||
attributes:
|
||||
|
||||
21
.github/ISSUE_TEMPLATE/feature-request_en.yml
vendored
Normal file
21
.github/ISSUE_TEMPLATE/feature-request_en.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Feature request
|
||||
title: "[Feature]: "
|
||||
labels: []
|
||||
description: "New features or existing feature improvements should use this template; issues that do not match will be closed directly"
|
||||
body:
|
||||
- type: dropdown
|
||||
attributes:
|
||||
label: This is a?
|
||||
description: New feature request or existing feature improvement
|
||||
options:
|
||||
- New feature
|
||||
- Existing feature improvement
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Detailed description
|
||||
description: Detailed description, the more detailed the better
|
||||
validations:
|
||||
required: true
|
||||
|
||||
2
.github/ISSUE_TEMPLATE/submit-plugin.yml
vendored
2
.github/ISSUE_TEMPLATE/submit-plugin.yml
vendored
@@ -1,7 +1,7 @@
|
||||
name: 提交新插件
|
||||
title: "[Plugin]: 请求登记新插件"
|
||||
labels: ["独立插件"]
|
||||
description: "本模板供且仅供提交新插件使用"
|
||||
description: "【供中文用户】本模板供且仅供提交新插件使用"
|
||||
body:
|
||||
- type: input
|
||||
attributes:
|
||||
|
||||
24
.github/ISSUE_TEMPLATE/submit-plugin_en.yml
vendored
Normal file
24
.github/ISSUE_TEMPLATE/submit-plugin_en.yml
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
name: Submit a new plugin
|
||||
title: "[Plugin]: Request to register a new plugin"
|
||||
labels: ["Independent Plugin"]
|
||||
description: "This template is only for submitting new plugins"
|
||||
body:
|
||||
- type: input
|
||||
attributes:
|
||||
label: Plugin name
|
||||
description: Fill in the name of the plugin
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Plugin code repository address
|
||||
description: Only support Github
|
||||
validations:
|
||||
required: true
|
||||
- type: textarea
|
||||
attributes:
|
||||
label: Plugin description
|
||||
description: The description of the plugin
|
||||
validations:
|
||||
required: true
|
||||
|
||||
27
.github/pull_request_template.md
vendored
27
.github/pull_request_template.md
vendored
@@ -1,20 +1,21 @@
|
||||
## 概述
|
||||
## 概述 / Overview
|
||||
|
||||
实现/解决/优化的内容:
|
||||
> 请在此部分填写你实现/解决/优化的内容:
|
||||
> Summary of what you implemented/solved/optimized:
|
||||
|
||||
## 检查清单
|
||||
## 检查清单 / Checklist
|
||||
|
||||
### PR 作者完成
|
||||
### PR 作者完成 / For PR author
|
||||
|
||||
*请在方括号间写`x`以打勾
|
||||
*请在方括号间写`x`以打勾 / Please tick the box with `x`*
|
||||
|
||||
- [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/LangBot/blob/master/CONTRIBUTING.md)了吗?
|
||||
- [ ] 与项目所有者沟通过了吗?
|
||||
- [ ] 我确定已自行测试所作的更改,确保功能符合预期。
|
||||
- [ ] 阅读仓库[贡献指引](https://github.com/RockChinQ/LangBot/blob/master/CONTRIBUTING.md)了吗? / Have you read the [contribution guide](https://github.com/RockChinQ/LangBot/blob/master/CONTRIBUTING.md)?
|
||||
- [ ] 与项目所有者沟通过了吗? / Have you communicated with the project maintainer?
|
||||
- [ ] 我确定已自行测试所作的更改,确保功能符合预期。 / I have tested the changes and ensured they work as expected.
|
||||
|
||||
### 项目所有者完成
|
||||
### 项目维护者完成 / For project maintainer
|
||||
|
||||
- [ ] 相关 issues 链接了吗?
|
||||
- [ ] 配置项写好了吗?迁移写好了吗?生效了吗?
|
||||
- [ ] 依赖加到 pyproject.toml 和 core/bootutils/deps.py 了吗
|
||||
- [ ] 文档编写了吗?
|
||||
- [ ] 相关 issues 链接了吗? / Have you linked the related issues?
|
||||
- [ ] 配置项写好了吗?迁移写好了吗?生效了吗? / Have you written the configuration items? Have you written the migration? Has it taken effect?
|
||||
- [ ] 依赖加到 pyproject.toml 和 core/bootutils/deps.py 了吗 / Have you added the dependencies to pyproject.toml and core/bootutils/deps.py?
|
||||
- [ ] 文档编写了吗? / Have you written the documentation?
|
||||
@@ -1 +0,0 @@
|
||||
3.12
|
||||
@@ -5,22 +5,27 @@
|
||||
### 贡献形式
|
||||
|
||||
- 提交PR,解决issues中提到的bug或期待的功能
|
||||
- 提交PR,实现您设想的功能(请先提出issue与作者沟通)
|
||||
- 优化代码架构,使各个模块的组织更加整洁优雅
|
||||
- 在issues中提出发现的bug或者期待的功能
|
||||
- 提交PR,实现您设想的功能(请先提出issue与项目维护者沟通)
|
||||
- 为本项目在其他社交平台撰写文章、制作视频等
|
||||
- 为本项目的衍生项目作出贡献,或开发插件增加功能
|
||||
|
||||
### 如何开始
|
||||
### 沟通语言规范
|
||||
|
||||
- 加入本项目交流群,一同探讨项目相关事务
|
||||
- 解决本项目或衍生项目的issues中亟待解决的问题
|
||||
- 阅读并完善本项目文档
|
||||
- 在各个社交媒体撰写本项目教程等
|
||||
- 在 PR 和 Commit Message 中请使用全英文
|
||||
- 对于中文用户,issue 中可以使用中文
|
||||
|
||||
### 代码规范
|
||||
<hr/>
|
||||
|
||||
- 代码中的注解`务必`符合Google风格的规范
|
||||
- 模块顶部的引入代码请遵循`系统模块`、`第三方库模块`、`自定义模块`的顺序进行引入
|
||||
- `不要`直接引入模块的特定属性,而是引入这个模块,再通过`xxx.yyy`的形式使用属性
|
||||
- 任何作用域的字段`必须`先声明后使用,并在声明处注明类型提示
|
||||
## Guidelines
|
||||
|
||||
### Contribution
|
||||
|
||||
- Submit PRs to solve bugs or features in the issues
|
||||
- Submit PRs to implement your ideas (Please create an issue first and communicate with the project maintainer)
|
||||
- Write articles or make videos about this project on other social platforms
|
||||
- Contribute to the development of derivative projects, or develop plugins to add features
|
||||
|
||||
### Spoken Language
|
||||
|
||||
- Use English in PRs and Commit Messages
|
||||
- For English users, you can use English in issues
|
||||
|
||||
@@ -6,7 +6,7 @@ COPY web ./web
|
||||
|
||||
RUN cd web && npm install && npm run build
|
||||
|
||||
FROM python:3.10.13-slim
|
||||
FROM python:3.12.7-slim
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
|
||||
@@ -26,7 +26,7 @@
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
[](https://gitcode.com/RockChinQ/LangBot)
|
||||
|
||||
[简体中文](README.md) / [English](README_EN.md) / [日本語](README_JP.md) / (PR for your language)
|
||||
简体中文 / [English](README_EN.md) / [日本語](README_JP.md) / (PR for your language)
|
||||
|
||||
</div>
|
||||
|
||||
@@ -120,6 +120,7 @@ docker compose up -d
|
||||
| [xAI](https://x.ai/) | ✅ | |
|
||||
| [智谱AI](https://open.bigmodel.cn/) | ✅ | |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型和 GPU 资源平台 |
|
||||
| [302 AI](https://share.302.ai/SuTG99) | ✅ | 大模型聚合平台 |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
||||
| [Dify](https://dify.ai) | ✅ | LLMOps 平台 |
|
||||
| [Ollama](https://ollama.com/) | ✅ | 本地大模型运行平台 |
|
||||
@@ -152,3 +153,9 @@ docker compose up -d
|
||||
<a href="https://github.com/RockChinQ/LangBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=RockChinQ/LangBot" />
|
||||
</a>
|
||||
|
||||
## 😎 保持更新
|
||||
|
||||
点击仓库右上角 Star 和 Watch 按钮,获取最新动态。
|
||||
|
||||

|
||||
|
||||
@@ -24,7 +24,7 @@
|
||||
[](https://github.com/RockChinQ/LangBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
[简体中文](README.md) / [English](README_EN.md) / [日本語](README_JP.md) / (PR for your language)
|
||||
[简体中文](README.md) / English / [日本語](README_JP.md) / (PR for your language)
|
||||
|
||||
</div>
|
||||
|
||||
@@ -118,6 +118,7 @@ Directly use the released version to run, see the [Manual Deployment](https://do
|
||||
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
|
||||
| [Dify](https://dify.ai) | ✅ | LLMOps platform |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | LLM and GPU resource platform |
|
||||
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLM gateway(MaaS) |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
||||
| [Ollama](https://ollama.com/) | ✅ | Local LLM running platform |
|
||||
| [LMStudio](https://lmstudio.ai/) | ✅ | Local LLM running platform |
|
||||
@@ -135,3 +136,9 @@ Thank you for the following [code contributors](https://github.com/RockChinQ/Lan
|
||||
<a href="https://github.com/RockChinQ/LangBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=RockChinQ/LangBot" />
|
||||
</a>
|
||||
|
||||
## 😎 Stay Ahead
|
||||
|
||||
Click the Star and Watch button in the upper right corner of the repository to get the latest updates.
|
||||
|
||||

|
||||
@@ -23,7 +23,7 @@
|
||||
[](https://github.com/RockChinQ/LangBot/releases/latest)
|
||||
<img src="https://img.shields.io/badge/python-3.10 ~ 3.13 -blue.svg" alt="python">
|
||||
|
||||
[简体中文](README.md) / [English](README_EN.md) / [日本語](README_JP.md) / (PR for your language)
|
||||
[简体中文](README_CN.md) / [English](README.md) / [日本語](README_JP.md) / (PR for your language)
|
||||
|
||||
</div>
|
||||
|
||||
@@ -116,6 +116,7 @@ LangBotはBTPanelにリストされています。BTPanelをインストール
|
||||
| [xAI](https://x.ai/) | ✅ | |
|
||||
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
|
||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型とGPUリソースプラットフォーム |
|
||||
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLMゲートウェイ(MaaS) |
|
||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
||||
| [Dify](https://dify.ai) | ✅ | LLMOpsプラットフォーム |
|
||||
| [Ollama](https://ollama.com/) | ✅ | ローカルLLM実行プラットフォーム |
|
||||
@@ -134,3 +135,9 @@ LangBot への貢献に対して、以下の [コード貢献者](https://github
|
||||
<a href="https://github.com/RockChinQ/LangBot/graphs/contributors">
|
||||
<img src="https://contrib.rocks/image?repo=RockChinQ/LangBot" />
|
||||
</a>
|
||||
|
||||
## 😎 最新情報を入手
|
||||
|
||||
リポジトリの右上にある Star と Watch ボタンをクリックして、最新の更新を取得してください。
|
||||
|
||||

|
||||
@@ -9,7 +9,6 @@ spec:
|
||||
components:
|
||||
ComponentTemplate:
|
||||
fromFiles:
|
||||
- pkg/platform/adapter.yaml
|
||||
- pkg/provider/modelmgr/requester.yaml
|
||||
MessagePlatformAdapter:
|
||||
fromDirs:
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from v1 import client
|
||||
from v1 import client # type: ignore
|
||||
|
||||
import asyncio
|
||||
|
||||
@@ -8,19 +8,13 @@ import json
|
||||
|
||||
class TestDifyClient:
|
||||
async def test_chat_messages(self):
|
||||
cln = client.AsyncDifyServiceClient(
|
||||
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||
)
|
||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL'))
|
||||
|
||||
async for chunk in cln.chat_messages(
|
||||
inputs={}, query='调用工具查看现在几点?', user='test'
|
||||
):
|
||||
async for chunk in cln.chat_messages(inputs={}, query='调用工具查看现在几点?', user='test'):
|
||||
print(json.dumps(chunk, ensure_ascii=False, indent=4))
|
||||
|
||||
async def test_upload_file(self):
|
||||
cln = client.AsyncDifyServiceClient(
|
||||
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||
)
|
||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL'))
|
||||
|
||||
file_bytes = open('img.png', 'rb').read()
|
||||
|
||||
@@ -32,9 +26,7 @@ class TestDifyClient:
|
||||
print(json.dumps(resp, ensure_ascii=False, indent=4))
|
||||
|
||||
async def test_workflow_run(self):
|
||||
cln = client.AsyncDifyServiceClient(
|
||||
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||
)
|
||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL'))
|
||||
|
||||
# resp = await cln.workflow_run(inputs={}, user="test")
|
||||
# # print(json.dumps(resp, ensure_ascii=False, indent=4))
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import asyncio
|
||||
import dingtalk_stream
|
||||
import dingtalk_stream # type: ignore
|
||||
from dingtalk_stream import AckMessage
|
||||
|
||||
|
||||
@@ -27,9 +27,3 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
||||
await asyncio.sleep(0.1) # 异步等待,避免阻塞
|
||||
|
||||
return self.incoming_message
|
||||
|
||||
|
||||
async def get_dingtalk_client(client_id, client_secret):
|
||||
from api import DingTalkClient # 延迟导入,避免循环导入
|
||||
|
||||
return DingTalkClient(client_id, client_secret)
|
||||
|
||||
@@ -2,7 +2,7 @@ import base64
|
||||
import json
|
||||
import time
|
||||
from typing import Callable
|
||||
import dingtalk_stream
|
||||
import dingtalk_stream # type: ignore
|
||||
from .EchoHandler import EchoTextHandler
|
||||
from .dingtalkevent import DingTalkEvent
|
||||
import httpx
|
||||
@@ -17,6 +17,7 @@ class DingTalkClient:
|
||||
robot_name: str,
|
||||
robot_code: str,
|
||||
markdown_card: bool,
|
||||
logger: None,
|
||||
):
|
||||
"""初始化 WebSocket 连接并自动启动"""
|
||||
self.credential = dingtalk_stream.Credential(client_id, client_secret)
|
||||
@@ -34,6 +35,7 @@ class DingTalkClient:
|
||||
self.robot_code = robot_code
|
||||
self.access_token_expiry_time = ''
|
||||
self.markdown_card = markdown_card
|
||||
self.logger = logger
|
||||
|
||||
async def get_access_token(self):
|
||||
url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
|
||||
@@ -47,8 +49,8 @@ class DingTalkClient:
|
||||
self.access_token = response_data.get('accessToken')
|
||||
expires_in = int(response_data.get('expireIn', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
except Exception as e:
|
||||
raise Exception(e)
|
||||
except Exception:
|
||||
await self.logger.error('failed to get access token in dingtalk')
|
||||
|
||||
async def is_token_expired(self):
|
||||
"""检查token是否过期"""
|
||||
@@ -73,7 +75,7 @@ class DingTalkClient:
|
||||
result = response.json()
|
||||
download_url = result.get('downloadUrl')
|
||||
else:
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
await self.logger.error(f'failed to get download url: {response.json()}')
|
||||
|
||||
if download_url:
|
||||
return await self.download_url_to_base64(download_url)
|
||||
@@ -84,10 +86,11 @@ class DingTalkClient:
|
||||
|
||||
if response.status_code == 200:
|
||||
file_bytes = response.content
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
||||
return base64_str
|
||||
mime_type = response.headers.get('Content-Type', 'application/octet-stream')
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
else:
|
||||
raise Exception('获取文件失败')
|
||||
await self.logger.error(f'failed to get files: {response.json()}')
|
||||
|
||||
async def get_audio_url(self, download_code: str):
|
||||
if not await self.check_access_token():
|
||||
@@ -103,7 +106,7 @@ class DingTalkClient:
|
||||
if download_url:
|
||||
return await self.download_url_to_base64(download_url)
|
||||
else:
|
||||
raise Exception('获取音频失败')
|
||||
await self.logger.error(f'failed to get audio: {response.json()}')
|
||||
else:
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
@@ -115,13 +118,20 @@ class DingTalkClient:
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
|
||||
async def send_message(self, content: str, incoming_message):
|
||||
async def send_message(self, content: str, incoming_message, at: bool):
|
||||
if self.markdown_card:
|
||||
self.EchoTextHandler.reply_markdown(
|
||||
title=self.robot_name + '的回答',
|
||||
text=content,
|
||||
incoming_message=incoming_message,
|
||||
)
|
||||
if at:
|
||||
self.EchoTextHandler.reply_markdown(
|
||||
title='@' + incoming_message.sender_nick + ' ' + content,
|
||||
text='@' + incoming_message.sender_nick + ' ' + content,
|
||||
incoming_message=incoming_message,
|
||||
)
|
||||
else:
|
||||
self.EchoTextHandler.reply_markdown(
|
||||
title=content,
|
||||
text=content,
|
||||
incoming_message=incoming_message,
|
||||
)
|
||||
else:
|
||||
self.EchoTextHandler.reply_text(content, incoming_message)
|
||||
|
||||
@@ -184,7 +194,10 @@ class DingTalkClient:
|
||||
del copy_message_data['IncomingMessage']
|
||||
# print("message_data:", json.dumps(copy_message_data, indent=4, ensure_ascii=False))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
if self.logger:
|
||||
await self.logger.error(f'Error in get_message: {traceback.format_exc()}')
|
||||
else:
|
||||
traceback.print_exc()
|
||||
|
||||
return message_data
|
||||
|
||||
@@ -207,9 +220,12 @@ class DingTalkClient:
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.post(url, headers=headers, json=data)
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
await self.logger.error(f'failed to send proactive massage to person: {traceback.format_exc()}')
|
||||
raise Exception(f'failed to send proactive massage to person: {traceback.format_exc()}')
|
||||
|
||||
async def send_proactive_message_to_group(self, target_id: str, content: str):
|
||||
if not await self.check_access_token():
|
||||
@@ -230,9 +246,12 @@ class DingTalkClient:
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
await client.post(url, headers=headers, json=data)
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}')
|
||||
raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}')
|
||||
|
||||
async def start(self):
|
||||
"""启动 WebSocket 连接,监听消息"""
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import dingtalk_stream
|
||||
import dingtalk_stream # type: ignore
|
||||
|
||||
|
||||
class DingTalkEvent(dict):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
|
||||
import time
|
||||
import traceback
|
||||
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from libs.wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
import xml.etree.ElementTree as ET
|
||||
from quart import Quart, request
|
||||
import hashlib
|
||||
@@ -23,7 +23,7 @@ xml_template = """
|
||||
|
||||
|
||||
class OAClient:
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str):
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
@@ -43,6 +43,7 @@ class OAClient:
|
||||
self.access_token_expiry_time = None
|
||||
self.msg_id_map = {}
|
||||
self.generated_content = {}
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
try:
|
||||
@@ -54,6 +55,7 @@ class OAClient:
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
if msg_signature is None:
|
||||
await self.logger.error('msg_signature不在请求体中')
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if request.method == 'GET':
|
||||
@@ -64,6 +66,7 @@ class OAClient:
|
||||
if check_signature == signature:
|
||||
return echostr # 验证成功返回echostr
|
||||
else:
|
||||
await self.logger.error('拒绝请求')
|
||||
raise Exception('拒绝请求')
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
@@ -72,6 +75,7 @@ class OAClient:
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
if ret != 0:
|
||||
await self.logger.error('消息解密失败')
|
||||
raise Exception('消息解密失败')
|
||||
|
||||
message_data = await self.get_message(xml_msg)
|
||||
@@ -114,6 +118,7 @@ class OAClient:
|
||||
return ''
|
||||
|
||||
except Exception:
|
||||
await self.logger.error(f'handle_callback_request失败: {traceback.format_exc()}')
|
||||
traceback.print_exc()
|
||||
|
||||
async def get_message(self, xml_msg: str):
|
||||
@@ -176,6 +181,7 @@ class OAClientForLongerResponse:
|
||||
AppID: str,
|
||||
Appsecret: str,
|
||||
LoadingMessage: str,
|
||||
logger: None,
|
||||
):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
@@ -197,6 +203,7 @@ class OAClientForLongerResponse:
|
||||
self.loading_message = LoadingMessage
|
||||
self.msg_queue = {}
|
||||
self.user_msg_queue = {}
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
try:
|
||||
@@ -207,6 +214,7 @@ class OAClientForLongerResponse:
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
|
||||
if msg_signature is None:
|
||||
await self.logger.error('msg_signature不在请求体中')
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if request.method == 'GET':
|
||||
@@ -221,6 +229,7 @@ class OAClientForLongerResponse:
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
if ret != 0:
|
||||
await self.logger.error('消息解密失败')
|
||||
raise Exception('消息解密失败')
|
||||
|
||||
# 解析 XML
|
||||
@@ -270,6 +279,7 @@ class OAClientForLongerResponse:
|
||||
return response_xml
|
||||
|
||||
except Exception:
|
||||
await self.logger.error(f'handle_callback_request失败: {traceback.format_exc()}')
|
||||
traceback.print_exc()
|
||||
|
||||
async def get_message(self, xml_msg: str):
|
||||
|
||||
@@ -3,7 +3,7 @@ from quart import request
|
||||
import httpx
|
||||
from quart import Quart
|
||||
from typing import Callable, Dict, Any
|
||||
from pkg.platform.types import events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
from .qqofficialevent import QQOfficialEvent
|
||||
import json
|
||||
import traceback
|
||||
@@ -34,7 +34,7 @@ def handle_validation(body: dict, bot_secret: str):
|
||||
|
||||
|
||||
class QQOfficialClient:
|
||||
def __init__(self, secret: str, token: str, app_id: str):
|
||||
def __init__(self, secret: str, token: str, app_id: str, logger: None):
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
@@ -49,6 +49,7 @@ class QQOfficialClient:
|
||||
self.base_url = 'https://api.sgroup.qq.com'
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
self.logger = logger
|
||||
|
||||
async def check_access_token(self):
|
||||
"""检查access_token是否存在"""
|
||||
@@ -77,6 +78,7 @@ class QQOfficialClient:
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
except Exception as e:
|
||||
await self.logger.error(f'获取access_token失败: {response_data}')
|
||||
raise Exception(f'获取access_token失败: {e}')
|
||||
|
||||
async def handle_callback_request(self):
|
||||
@@ -102,7 +104,7 @@ class QQOfficialClient:
|
||||
return {'code': 0, 'message': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
return {'error': str(e)}, 400
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
@@ -178,9 +180,11 @@ class QQOfficialClient:
|
||||
'msg_id': msg_id,
|
||||
}
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
response_data = response.json()
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
await self.logger.error(f'发送私聊消息失败: {response_data}')
|
||||
raise ValueError(response)
|
||||
|
||||
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
|
||||
@@ -203,6 +207,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
await self.logger.error(f'发送群聊消息失败:{response.json()}')
|
||||
raise Exception(response.read().decode())
|
||||
|
||||
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
|
||||
@@ -225,6 +230,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
await self.logger.error(f'发送频道群聊消息失败: {response.json()}')
|
||||
raise Exception(response)
|
||||
|
||||
async def send_channle_private_text_msg(self, guild_id: str, content: str, msg_id: str):
|
||||
@@ -247,6 +253,7 @@ class QQOfficialClient:
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
await self.logger.error(f'发送频道私聊消息失败: {response.json()}')
|
||||
raise Exception(response)
|
||||
|
||||
async def is_token_expired(self):
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
import json
|
||||
import traceback
|
||||
from quart import Quart, jsonify, request
|
||||
from slack_sdk.web.async_client import AsyncWebClient
|
||||
from .slackevent import SlackEvent
|
||||
from typing import Callable
|
||||
from pkg.platform.types import events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
|
||||
|
||||
class SlackClient:
|
||||
def __init__(self, bot_token: str, signing_secret: str):
|
||||
def __init__(self, bot_token: str, signing_secret: str, logger: None):
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.app = Quart(__name__)
|
||||
@@ -19,6 +20,7 @@ class SlackClient:
|
||||
'example': [],
|
||||
}
|
||||
self.bot_user_id = None # 避免机器人回复自己的消息
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
try:
|
||||
@@ -49,6 +51,7 @@ class SlackClient:
|
||||
return jsonify({'status': 'ok'})
|
||||
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
raise (e)
|
||||
|
||||
async def _handle_message(self, event: SlackEvent):
|
||||
@@ -78,6 +81,7 @@ class SlackClient:
|
||||
self.bot_user_id = response['message']['bot_id']
|
||||
return
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error in send_message: {e}')
|
||||
raise e
|
||||
|
||||
async def send_message_to_one(self, text: str, user_id: str):
|
||||
@@ -88,6 +92,7 @@ class SlackClient:
|
||||
|
||||
return
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error in send_message: {traceback.format_exc()}')
|
||||
raise e
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
|
||||
@@ -1 +1,4 @@
|
||||
from .client import WeChatPadClient
|
||||
from .client import WeChatPadClient
|
||||
|
||||
|
||||
__all__ = ['WeChatPadClient']
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from libs.wechatpad_api.util.http_util import async_request, post_json
|
||||
from libs.wechatpad_api.util.http_util import post_json
|
||||
|
||||
|
||||
class ChatRoomApi:
|
||||
@@ -7,8 +7,6 @@ class ChatRoomApi:
|
||||
self.token = token
|
||||
|
||||
def get_chatroom_member_detail(self, chatroom_name):
|
||||
params = {
|
||||
"ChatRoomName": chatroom_name
|
||||
}
|
||||
params = {'ChatRoomName': chatroom_name}
|
||||
url = self.base_url + '/group/GetChatroomMemberDetail'
|
||||
return post_json(url, token=self.token, data=params)
|
||||
|
||||
@@ -1,32 +1,23 @@
|
||||
from libs.wechatpad_api.util.http_util import async_request, post_json
|
||||
from libs.wechatpad_api.util.http_util import post_json
|
||||
import httpx
|
||||
import base64
|
||||
|
||||
|
||||
class DownloadApi:
|
||||
def __init__(self, base_url, token):
|
||||
self.base_url = base_url
|
||||
self.token = token
|
||||
|
||||
def send_download(self, aeskey, file_type, file_url):
|
||||
json_data = {
|
||||
"AesKey": aeskey,
|
||||
"FileType": file_type,
|
||||
"FileURL": file_url
|
||||
}
|
||||
url = self.base_url + "/message/SendCdnDownload"
|
||||
json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
|
||||
url = self.base_url + '/message/SendCdnDownload'
|
||||
return post_json(url, token=self.token, data=json_data)
|
||||
|
||||
def get_msg_voice(self,buf_id, length, new_msgid):
|
||||
json_data = {
|
||||
"Bufid": buf_id,
|
||||
"Length": length,
|
||||
"NewMsgId": new_msgid,
|
||||
"ToUserName": ""
|
||||
}
|
||||
url = self.base_url + "/message/GetMsgVoice"
|
||||
def get_msg_voice(self, buf_id, length, new_msgid):
|
||||
json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
|
||||
url = self.base_url + '/message/GetMsgVoice'
|
||||
return post_json(url, token=self.token, data=json_data)
|
||||
|
||||
|
||||
async def download_url_to_base64(self, download_url):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
@@ -36,4 +27,4 @@ class DownloadApi:
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
||||
return base64_str
|
||||
else:
|
||||
raise Exception('获取文件失败')
|
||||
raise Exception('获取文件失败')
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
from libs.wechatpad_api.util.http_util import post_json,async_request
|
||||
from typing import List, Dict, Any, Optional
|
||||
|
||||
|
||||
class FriendApi:
|
||||
"""联系人API类,处理所有与联系人相关的操作"""
|
||||
|
||||
def __init__(self, base_url: str, token: str):
|
||||
self.base_url = base_url
|
||||
self.token = token
|
||||
|
||||
|
||||
@@ -1,37 +1,34 @@
|
||||
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json
|
||||
from libs.wechatpad_api.util.http_util import post_json, get_json
|
||||
|
||||
|
||||
class LoginApi:
|
||||
def __init__(self, base_url: str, token: str = None, admin_key: str = None):
|
||||
'''
|
||||
"""
|
||||
|
||||
Args:
|
||||
base_url: 原始路径
|
||||
token: token
|
||||
admin_key: 管理员key
|
||||
'''
|
||||
"""
|
||||
self.base_url = base_url
|
||||
self.token = token
|
||||
# self.admin_key = admin_key
|
||||
|
||||
def get_token(self, admin_key, day: int=365):
|
||||
def get_token(self, admin_key, day: int = 365):
|
||||
# 获取普通token
|
||||
url = f"{self.base_url}/admin/GenAuthKey1"
|
||||
json_data = {
|
||||
"Count": 1,
|
||||
"Days": day
|
||||
}
|
||||
url = f'{self.base_url}/admin/GenAuthKey1'
|
||||
json_data = {'Count': 1, 'Days': day}
|
||||
return post_json(base_url=url, token=admin_key, data=json_data)
|
||||
|
||||
def get_login_qr(self, Proxy: str = ""):
|
||||
'''
|
||||
def get_login_qr(self, Proxy: str = ''):
|
||||
"""
|
||||
|
||||
Args:
|
||||
Proxy:异地使用时代理
|
||||
|
||||
Returns:json数据
|
||||
|
||||
'''
|
||||
"""
|
||||
"""
|
||||
|
||||
{
|
||||
@@ -49,54 +46,37 @@ class LoginApi:
|
||||
}
|
||||
|
||||
"""
|
||||
#获取登录二维码
|
||||
url = f"{self.base_url}/login/GetLoginQrCodeNew"
|
||||
# 获取登录二维码
|
||||
url = f'{self.base_url}/login/GetLoginQrCodeNew'
|
||||
check = False
|
||||
if Proxy != "":
|
||||
if Proxy != '':
|
||||
check = True
|
||||
json_data = {
|
||||
"Check": check,
|
||||
"Proxy": Proxy
|
||||
}
|
||||
json_data = {'Check': check, 'Proxy': Proxy}
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
|
||||
def get_login_status(self):
|
||||
# 获取登录状态
|
||||
url = f'{self.base_url}/login/GetLoginStatus'
|
||||
return get_json(base_url=url, token=self.token)
|
||||
|
||||
|
||||
|
||||
def logout(self):
|
||||
# 退出登录
|
||||
url = f'{self.base_url}/login/LogOut'
|
||||
return post_json(base_url=url, token=self.token)
|
||||
|
||||
|
||||
|
||||
|
||||
def wake_up_login(self, Proxy: str = ""):
|
||||
def wake_up_login(self, Proxy: str = ''):
|
||||
# 唤醒登录
|
||||
url = f'{self.base_url}/login/WakeUpLogin'
|
||||
check = False
|
||||
if Proxy != "":
|
||||
if Proxy != '':
|
||||
check = True
|
||||
json_data = {
|
||||
"Check": check,
|
||||
"Proxy": ""
|
||||
}
|
||||
json_data = {'Check': check, 'Proxy': ''}
|
||||
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
|
||||
|
||||
def login(self,admin_key):
|
||||
def login(self, admin_key):
|
||||
login_status = self.get_login_status()
|
||||
if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信":
|
||||
print("token已经失效,重新获取")
|
||||
if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
|
||||
print('token已经失效,重新获取')
|
||||
token_data = self.get_token(admin_key)
|
||||
self.token = token_data["Data"][0]
|
||||
|
||||
|
||||
|
||||
self.token = token_data['Data'][0]
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
|
||||
from libs.wechatpad_api.util.http_util import async_request, post_json
|
||||
from libs.wechatpad_api.util.http_util import post_json
|
||||
|
||||
|
||||
class MessageApi:
|
||||
@@ -7,8 +6,8 @@ class MessageApi:
|
||||
self.base_url = base_url
|
||||
self.token = token
|
||||
|
||||
def post_text(self, to_wxid, content, ats: list= []):
|
||||
'''
|
||||
def post_text(self, to_wxid, content, ats: list = []):
|
||||
"""
|
||||
|
||||
Args:
|
||||
app_id: 微信id
|
||||
@@ -18,106 +17,64 @@ class MessageApi:
|
||||
|
||||
Returns:
|
||||
|
||||
'''
|
||||
url = self.base_url + "/message/SendTextMessage"
|
||||
"""
|
||||
url = self.base_url + '/message/SendTextMessage'
|
||||
"""发送文字消息"""
|
||||
json_data = {
|
||||
"MsgItem": [
|
||||
{
|
||||
"AtWxIDList": ats,
|
||||
"ImageContent": "",
|
||||
"MsgType": 0,
|
||||
"TextContent": content,
|
||||
"ToUserName": to_wxid
|
||||
}
|
||||
]
|
||||
}
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
'MsgItem': [
|
||||
{'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
|
||||
]
|
||||
}
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
|
||||
|
||||
|
||||
def post_image(self, to_wxid, img_url, ats: list= []):
|
||||
def post_image(self, to_wxid, img_url, ats: list = []):
|
||||
"""发送图片消息"""
|
||||
# 这里好像可以尝试发送多个暂时未测试
|
||||
json_data = {
|
||||
"MsgItem": [
|
||||
{
|
||||
"AtWxIDList": ats,
|
||||
"ImageContent": img_url,
|
||||
"MsgType": 0,
|
||||
"TextContent": '',
|
||||
"ToUserName": to_wxid
|
||||
}
|
||||
'MsgItem': [
|
||||
{'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
|
||||
]
|
||||
}
|
||||
url = self.base_url + "/message/SendImageMessage"
|
||||
url = self.base_url + '/message/SendImageMessage'
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
|
||||
"""发送语音消息"""
|
||||
json_data = {
|
||||
"ToUserName": to_wxid,
|
||||
"VoiceData": voice_data,
|
||||
"VoiceFormat": voice_forma,
|
||||
"VoiceSecond": voice_duration
|
||||
'ToUserName': to_wxid,
|
||||
'VoiceData': voice_data,
|
||||
'VoiceFormat': voice_forma,
|
||||
'VoiceSecond': voice_duration,
|
||||
}
|
||||
url = self.base_url + "/message/SendVoice"
|
||||
url = self.base_url + '/message/SendVoice'
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
|
||||
"""发送名片消息"""
|
||||
param = {
|
||||
"CardAlias": alias,
|
||||
"CardFlag": flag,
|
||||
"CardNickName": nick_name,
|
||||
"CardWxId": name_card_wxid,
|
||||
"ToUserName": to_wxid
|
||||
'CardAlias': alias,
|
||||
'CardFlag': flag,
|
||||
'CardNickName': nick_name,
|
||||
'CardWxId': name_card_wxid,
|
||||
'ToUserName': to_wxid,
|
||||
}
|
||||
url = f"{self.base_url}/message/ShareCardMessage"
|
||||
url = f'{self.base_url}/message/ShareCardMessage'
|
||||
return post_json(base_url=url, token=self.token, data=param)
|
||||
|
||||
def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0):
|
||||
def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
|
||||
"""发送emoji消息"""
|
||||
json_data = {
|
||||
"EmojiList": [
|
||||
{
|
||||
"EmojiMd5": emoji_md5,
|
||||
"EmojiSize": emoji_size,
|
||||
"ToUserName": to_wxid
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.base_url}/message/SendEmojiMessage"
|
||||
json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
|
||||
url = f'{self.base_url}/message/SendEmojiMessage'
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
def post_app_msg(self, to_wxid,xml_data, contenttype:int=0):
|
||||
def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
|
||||
"""发送appmsg消息"""
|
||||
json_data = {
|
||||
"AppList": [
|
||||
{
|
||||
"ContentType": contenttype,
|
||||
"ContentXML": xml_data,
|
||||
"ToUserName": to_wxid
|
||||
}
|
||||
]
|
||||
}
|
||||
url = f"{self.base_url}/message/SendAppMessage"
|
||||
json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
|
||||
url = f'{self.base_url}/message/SendAppMessage'
|
||||
return post_json(base_url=url, token=self.token, data=json_data)
|
||||
|
||||
|
||||
|
||||
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
|
||||
"""撤回消息"""
|
||||
param = {
|
||||
"ClientMsgId": msg_id,
|
||||
"CreateTime": create_time,
|
||||
"NewMsgId": new_msg_id,
|
||||
"ToUserName": to_wxid
|
||||
}
|
||||
url = f"{self.base_url}/message/RevokeMsg"
|
||||
return post_json(base_url=url, token=self.token, data=param)
|
||||
param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
|
||||
url = f'{self.base_url}/message/RevokeMsg'
|
||||
return post_json(base_url=url, token=self.token, data=param)
|
||||
|
||||
@@ -12,12 +12,9 @@ class UserApi:
|
||||
|
||||
return get_json(base_url=url, token=self.token)
|
||||
|
||||
def get_qr_code(self, recover:bool=True, style:int=8):
|
||||
def get_qr_code(self, recover: bool = True, style: int = 8):
|
||||
"""获取自己的二维码"""
|
||||
param = {
|
||||
"Recover": recover,
|
||||
"Style": style
|
||||
}
|
||||
param = {'Recover': recover, 'Style': style}
|
||||
url = f'{self.base_url}/user/GetMyQRCode'
|
||||
return post_json(base_url=url, token=self.token, data=param)
|
||||
|
||||
@@ -26,12 +23,8 @@ class UserApi:
|
||||
url = f'{self.base_url}/equipment/GetSafetyInfo'
|
||||
return post_json(base_url=url, token=self.token)
|
||||
|
||||
|
||||
|
||||
async def update_head_img(self, head_img_base64):
|
||||
async def update_head_img(self, head_img_base64):
|
||||
"""修改头像"""
|
||||
param = {
|
||||
"Base64": head_img_base64
|
||||
}
|
||||
param = {'Base64': head_img_base64}
|
||||
url = f'{self.base_url}/user/UploadHeadImage'
|
||||
return await async_request(base_url=url, token_key=self.token, json=param)
|
||||
return await async_request(base_url=url, token_key=self.token, json=param)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
|
||||
from libs.wechatpad_api.api.login import LoginApi
|
||||
from libs.wechatpad_api.api.friend import FriendApi
|
||||
from libs.wechatpad_api.api.message import MessageApi
|
||||
@@ -7,28 +6,26 @@ from libs.wechatpad_api.api.downloadpai import DownloadApi
|
||||
from libs.wechatpad_api.api.chatroom import ChatRoomApi
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class WeChatPadClient:
|
||||
def __init__(self,base_url, token):
|
||||
def __init__(self, base_url, token, logger=None):
|
||||
self._login_api = LoginApi(base_url, token)
|
||||
self._friend_api = FriendApi(base_url, token)
|
||||
self._message_api = MessageApi(base_url, token)
|
||||
self._user_api = UserApi(base_url, token)
|
||||
self._download_api = DownloadApi(base_url, token)
|
||||
self._chatroom_api = ChatRoomApi(base_url, token)
|
||||
self.logger = logger
|
||||
|
||||
def get_token(self,admin_key, day: int):
|
||||
'''获取token'''
|
||||
def get_token(self, admin_key, day: int):
|
||||
"""获取token"""
|
||||
return self._login_api.get_token(admin_key, day)
|
||||
|
||||
def get_login_qr(self, Proxy:str=""):
|
||||
def get_login_qr(self, Proxy: str = ''):
|
||||
"""登录二维码"""
|
||||
return self._login_api.get_login_qr(Proxy=Proxy)
|
||||
|
||||
def awaken_login(self, Proxy:str=""):
|
||||
'''唤醒登录'''
|
||||
def awaken_login(self, Proxy: str = ''):
|
||||
"""唤醒登录"""
|
||||
return self._login_api.wake_up_login(Proxy=Proxy)
|
||||
|
||||
def log_out(self):
|
||||
@@ -39,59 +36,57 @@ class WeChatPadClient:
|
||||
"""获取登录状态"""
|
||||
return self._login_api.get_login_status()
|
||||
|
||||
def send_text_message(self, to_wxid, message, ats: list=[]):
|
||||
def send_text_message(self, to_wxid, message, ats: list = []):
|
||||
"""发送文本消息"""
|
||||
return self._message_api.post_text(to_wxid, message, ats)
|
||||
return self._message_api.post_text(to_wxid, message, ats)
|
||||
|
||||
def send_image_message(self, to_wxid, img_url, ats: list=[]):
|
||||
def send_image_message(self, to_wxid, img_url, ats: list = []):
|
||||
"""发送图片消息"""
|
||||
return self._message_api.post_image(to_wxid, img_url, ats)
|
||||
return self._message_api.post_image(to_wxid, img_url, ats)
|
||||
|
||||
def send_voice_message(self, to_wxid, voice_data, voice_forma, voice_duration):
|
||||
"""发送音频消息"""
|
||||
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
|
||||
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
|
||||
|
||||
def send_app_message(self, to_wxid, app_message, type):
|
||||
"""发送app消息"""
|
||||
return self._message_api.post_app_msg(to_wxid, app_message, type)
|
||||
return self._message_api.post_app_msg(to_wxid, app_message, type)
|
||||
|
||||
def send_emoji_message(self, to_wxid, emoji_md5, emoji_size):
|
||||
"""发送emoji消息"""
|
||||
return self._message_api.post_emoji(to_wxid,emoji_md5,emoji_size)
|
||||
return self._message_api.post_emoji(to_wxid, emoji_md5, emoji_size)
|
||||
|
||||
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
|
||||
"""撤回消息"""
|
||||
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
|
||||
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
|
||||
|
||||
def get_profile(self):
|
||||
"""获取用户信息"""
|
||||
return self._user_api.get_profile()
|
||||
|
||||
def get_qr_code(self, recover:bool=True, style:int=8):
|
||||
def get_qr_code(self, recover: bool = True, style: int = 8):
|
||||
"""获取用户二维码"""
|
||||
return self._user_api.get_qr_code(recover=recover, style=style)
|
||||
return self._user_api.get_qr_code(recover=recover, style=style)
|
||||
|
||||
def get_safety_info(self):
|
||||
"""获取设备信息"""
|
||||
return self._user_api.get_safety_info()
|
||||
return self._user_api.get_safety_info()
|
||||
|
||||
def update_head_img(self, head_img_base64):
|
||||
def update_head_img(self, head_img_base64):
|
||||
"""上传用户头像"""
|
||||
return self._user_api.update_head_img(head_img_base64)
|
||||
return self._user_api.update_head_img(head_img_base64)
|
||||
|
||||
def cdn_download(self, aeskey, file_type, file_url):
|
||||
"""cdn下载"""
|
||||
return self._download_api.send_download( aeskey, file_type, file_url)
|
||||
return self._download_api.send_download(aeskey, file_type, file_url)
|
||||
|
||||
def get_msg_voice(self,buf_id, length, msgid):
|
||||
def get_msg_voice(self, buf_id, length, msgid):
|
||||
"""下载语音"""
|
||||
return self._download_api.get_msg_voice(buf_id, length, msgid)
|
||||
|
||||
async def download_base64(self,url):
|
||||
async def download_base64(self, url):
|
||||
return await self._download_api.download_url_to_base64(download_url=url)
|
||||
|
||||
def get_chatroom_member_detail(self, chatroom_name):
|
||||
"""查看群成员详情"""
|
||||
return self._chatroom_api.get_chatroom_member_detail(chatroom_name)
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import requests
|
||||
import aiohttp
|
||||
|
||||
|
||||
def post_json(base_url, token, data=None):
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
|
||||
url = base_url + f'?key={token}'
|
||||
|
||||
@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
|
||||
else:
|
||||
raise RuntimeError(response.text)
|
||||
except Exception as e:
|
||||
print(f"http请求失败, url={url}, exception={e}")
|
||||
print(f'http请求失败, url={url}, exception={e}')
|
||||
raise RuntimeError(str(e))
|
||||
|
||||
def get_json(base_url, token):
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
|
||||
def get_json(base_url, token):
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
|
||||
url = base_url + f'?key={token}'
|
||||
|
||||
@@ -39,21 +36,18 @@ def get_json(base_url, token):
|
||||
else:
|
||||
raise RuntimeError(response.text)
|
||||
except Exception as e:
|
||||
print(f"http请求失败, url={url}, exception={e}")
|
||||
print(f'http请求失败, url={url}, exception={e}')
|
||||
raise RuntimeError(str(e))
|
||||
|
||||
import aiohttp
|
||||
import asyncio
|
||||
|
||||
|
||||
async def async_request(
|
||||
base_url: str,
|
||||
token_key: str,
|
||||
method: str = 'POST',
|
||||
params: dict = None,
|
||||
# headers: dict = None,
|
||||
data: dict = None,
|
||||
json: dict = None
|
||||
base_url: str,
|
||||
token_key: str,
|
||||
method: str = 'POST',
|
||||
params: dict = None,
|
||||
# headers: dict = None,
|
||||
data: dict = None,
|
||||
json: dict = None,
|
||||
):
|
||||
"""
|
||||
通用异步请求函数
|
||||
@@ -67,18 +61,11 @@ async def async_request(
|
||||
:param json: JSON数据
|
||||
:return: 响应文本
|
||||
"""
|
||||
headers = {
|
||||
'Content-Type': 'application/json'
|
||||
}
|
||||
url = f"{base_url}?key={token_key}"
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
url = f'{base_url}?key={token_key}'
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.request(
|
||||
method=method,
|
||||
url=url,
|
||||
params=params,
|
||||
headers=headers,
|
||||
data=data,
|
||||
json=json
|
||||
method=method, url=url, params=params, headers=headers, data=data, json=json
|
||||
) as response:
|
||||
response.raise_for_status() # 如果状态码不是200,抛出异常
|
||||
result = await response.json()
|
||||
@@ -89,4 +76,3 @@ async def async_request(
|
||||
# return await result
|
||||
# else:
|
||||
# raise RuntimeError("请求失败",response.text)
|
||||
|
||||
|
||||
@@ -1,31 +1,34 @@
|
||||
import qrcode
|
||||
|
||||
|
||||
def print_green(text):
|
||||
print(f"\033[32m{text}\033[0m")
|
||||
print(f'\033[32m{text}\033[0m')
|
||||
|
||||
|
||||
def print_yellow(text):
|
||||
print(f"\033[33m{text}\033[0m")
|
||||
print(f'\033[33m{text}\033[0m')
|
||||
|
||||
|
||||
def print_red(text):
|
||||
print(f"\033[31m{text}\033[0m")
|
||||
print(f'\033[31m{text}\033[0m')
|
||||
|
||||
|
||||
def make_and_print_qr(url):
|
||||
"""生成并打印二维码
|
||||
|
||||
|
||||
Args:
|
||||
url: 需要生成二维码的URL字符串
|
||||
|
||||
|
||||
Returns:
|
||||
None
|
||||
|
||||
|
||||
功能:
|
||||
1. 在终端打印二维码的ASCII图形
|
||||
2. 同时提供在线二维码生成链接作为备选
|
||||
"""
|
||||
print_green("请扫描下方二维码登录")
|
||||
print_green('请扫描下方二维码登录')
|
||||
qr = qrcode.QRCode()
|
||||
qr.add_data(url)
|
||||
qr.make()
|
||||
qr.print_ascii(invert=True)
|
||||
print_green(f"也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}")
|
||||
|
||||
print_green(f'也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}')
|
||||
|
||||
@@ -3,11 +3,12 @@ from .WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
import base64
|
||||
import binascii
|
||||
import httpx
|
||||
import traceback
|
||||
from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
from .wecomevent import WecomEvent
|
||||
from pkg.platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
@@ -19,6 +20,7 @@ class WecomClient:
|
||||
token: str,
|
||||
EncodingAESKey: str,
|
||||
contacts_secret: str,
|
||||
logger: None,
|
||||
):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
@@ -28,8 +30,8 @@ class WecomClient:
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.access_token = ''
|
||||
self.secret_for_contacts = contacts_secret
|
||||
self.logger = logger
|
||||
self.app = Quart(__name__)
|
||||
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
@@ -55,6 +57,7 @@ class WecomClient:
|
||||
if 'access_token' in data:
|
||||
return data['access_token']
|
||||
else:
|
||||
await self.logger.error(f'获取accesstoken失败:{response.json()}')
|
||||
raise Exception(f'未获取access token: {data}')
|
||||
|
||||
async def get_users(self):
|
||||
@@ -126,6 +129,7 @@ class WecomClient:
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
except Exception as e:
|
||||
await self.logger.error(f'发送图片失败:{data}')
|
||||
raise Exception('Failed to send image: ' + str(e))
|
||||
|
||||
# 企业微信错误码40014和42001,代表accesstoken问题
|
||||
@@ -160,6 +164,7 @@ class WecomClient:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_private_msg(user_id, agent_id, content)
|
||||
if data['errcode'] != 0:
|
||||
await self.logger.error(f'发送消息失败:{data}')
|
||||
raise Exception('Failed to send message: ' + str(data))
|
||||
|
||||
async def handle_callback_request(self):
|
||||
@@ -171,17 +176,20 @@ class WecomClient:
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
await self.logger.error('验证失败')
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
await self.logger.error('消息解密失败')
|
||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||
|
||||
# 解析消息并处理
|
||||
@@ -193,6 +201,7 @@ class WecomClient:
|
||||
|
||||
return 'success'
|
||||
except Exception as e:
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
return f'Error processing request: {str(e)}', 400
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
@@ -291,6 +300,7 @@ class WecomClient:
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||
else:
|
||||
await self.logger.error('Image对象出错')
|
||||
raise ValueError('image对象出错')
|
||||
|
||||
# 设置 multipart/form-data 格式的文件
|
||||
@@ -314,6 +324,7 @@ class WecomClient:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
media_id = await self.upload_to_work(image)
|
||||
if data.get('errcode', 0) != 0:
|
||||
await self.logger.error(f'上传图片失败:{data}')
|
||||
raise Exception('failed to upload file')
|
||||
|
||||
media_id = data.get('media_id')
|
||||
|
||||
@@ -8,12 +8,12 @@ from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable
|
||||
from .wecomcsevent import WecomCSEvent
|
||||
from pkg.platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
class WecomCSClient:
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str):
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts = ''
|
||||
@@ -21,6 +21,7 @@ class WecomCSClient:
|
||||
self.aes = EncodingAESKey
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.access_token = ''
|
||||
self.logger = logger
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
@@ -186,6 +187,7 @@ class WecomCSClient:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_text_msg(open_kfid, external_userid, msgid, content)
|
||||
if data['errcode'] != 0:
|
||||
await self.logger.error(f'发送消息失败:{data}')
|
||||
raise Exception('Failed to send message')
|
||||
return data
|
||||
|
||||
@@ -224,7 +226,10 @@ class WecomCSClient:
|
||||
|
||||
return 'success'
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
if self.logger:
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
else:
|
||||
traceback.print_exc()
|
||||
return f'Error processing request: {str(e)}', 400
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
|
||||
32
main.py
32
main.py
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import argparse
|
||||
# LangBot 终端启动入口
|
||||
# 在此层级解决依赖项检查。
|
||||
# LangBot/main.py
|
||||
@@ -10,12 +11,16 @@ asciiart = r"""
|
||||
|____\__,_|_||_\__, |___/\___/\__|
|
||||
|___/
|
||||
|
||||
⭐️开源地址: https://github.com/RockChinQ/LangBot
|
||||
📖文档地址: https://docs.langbot.app
|
||||
⭐️ Open Source 开源地址: https://github.com/RockChinQ/LangBot
|
||||
📖 Documentation 文档地址: https://docs.langbot.app
|
||||
"""
|
||||
|
||||
|
||||
async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
parser = argparse.ArgumentParser(description='LangBot')
|
||||
parser.add_argument('--skip-plugin-deps-check', action='store_true', help='跳过插件依赖项检查', default=False)
|
||||
args = parser.parse_args()
|
||||
|
||||
print(asciiart)
|
||||
|
||||
import sys
|
||||
@@ -28,22 +33,27 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
if missing_deps:
|
||||
print('以下依赖包未安装,将自动安装,请完成后重启程序:')
|
||||
print(
|
||||
'These dependencies are missing, they will be installed automatically, please restart the program after completion:'
|
||||
)
|
||||
for dep in missing_deps:
|
||||
print('-', dep)
|
||||
await deps.install_deps(missing_deps)
|
||||
print('已自动安装缺失的依赖包,请重启程序。')
|
||||
print('The missing dependencies have been installed automatically, please restart the program.')
|
||||
sys.exit(0)
|
||||
|
||||
# check plugin deps
|
||||
await deps.precheck_plugin_deps()
|
||||
if not args.skip_plugin_deps_check:
|
||||
await deps.precheck_plugin_deps()
|
||||
|
||||
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||
import pydantic.version
|
||||
# # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||
# import pydantic.version
|
||||
|
||||
if pydantic.version.VERSION < '2.0':
|
||||
import pydantic
|
||||
# if pydantic.version.VERSION < '2.0':
|
||||
# import pydantic
|
||||
|
||||
sys.modules['pydantic.v1'] = pydantic
|
||||
# sys.modules['pydantic.v1'] = pydantic
|
||||
|
||||
# 检查配置文件
|
||||
|
||||
@@ -53,6 +63,7 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
||||
|
||||
if generated_files:
|
||||
print('以下文件不存在,已自动生成:')
|
||||
print('Following files do not exist and have been automatically generated:')
|
||||
for file in generated_files:
|
||||
print('-', file)
|
||||
|
||||
@@ -69,9 +80,10 @@ if __name__ == '__main__':
|
||||
if sys.version_info < (3, 10, 1):
|
||||
print('需要 Python 3.10.1 及以上版本,当前 Python 版本为:', sys.version)
|
||||
input('按任意键退出...')
|
||||
print('Your Python version is not supported. Please exit the program by pressing any key.')
|
||||
exit(1)
|
||||
|
||||
# 检查本目录是否有main.py,且包含LangBot字符串
|
||||
# Check if the current directory is the LangBot project root directory
|
||||
invalid_pwd = False
|
||||
|
||||
if not os.path.exists('main.py'):
|
||||
@@ -84,6 +96,8 @@ if __name__ == '__main__':
|
||||
if invalid_pwd:
|
||||
print('请在 LangBot 项目根目录下以命令形式运行此程序。')
|
||||
input('按任意键退出...')
|
||||
print('Please run this program in the LangBot project root directory in command form.')
|
||||
print('Press any key to exit...')
|
||||
exit(1)
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
22
pkg/api/http/controller/groups/files.py
Normal file
22
pkg/api/http/controller/groups/files.py
Normal file
@@ -0,0 +1,22 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
import mimetypes
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('files', '/api/v1/files')
|
||||
class FilesRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/image/<image_key>', methods=['GET'], auth_type=group.AuthType.NONE)
|
||||
async def _(image_key: str) -> quart.Response:
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(image_key):
|
||||
return quart.Response(status=404)
|
||||
|
||||
image_bytes = await self.ap.storage_mgr.storage_provider.load(image_key)
|
||||
mime_type = mimetypes.guess_type(image_key)[0]
|
||||
if mime_type is None:
|
||||
mime_type = 'image/jpeg'
|
||||
|
||||
return quart.Response(image_bytes, mimetype=mime_type)
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import quart
|
||||
|
||||
from .. import group
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('pipelines', '/api/v1/pipelines')
|
||||
79
pkg/api/http/controller/groups/pipelines/webchat.py
Normal file
79
pkg/api/http/controller/groups/pipelines/webchat.py
Normal file
@@ -0,0 +1,79 @@
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
|
||||
|
||||
@group.group_class('webchat', '/api/v1/pipelines/<pipeline_uuid>/chat')
|
||||
class WebChatDebugRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/send', methods=['POST'])
|
||||
async def send_message(pipeline_uuid: str) -> str:
|
||||
"""发送调试消息到流水线"""
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
session_type = data.get('session_type', 'person')
|
||||
message_chain_obj = data.get('message', [])
|
||||
|
||||
if not message_chain_obj:
|
||||
return self.http_status(400, -1, 'message is required')
|
||||
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
webchat_adapter = self.ap.platform_mgr.webchat_proxy_bot.adapter
|
||||
|
||||
if not webchat_adapter:
|
||||
return self.http_status(404, -1, 'WebChat adapter not found')
|
||||
|
||||
result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'message': result,
|
||||
}
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
@self.route('/messages/<session_type>', methods=['GET'])
|
||||
async def get_messages(pipeline_uuid: str, session_type: str) -> str:
|
||||
"""获取调试消息历史"""
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
webchat_adapter = self.ap.platform_mgr.webchat_proxy_bot.adapter
|
||||
|
||||
if not webchat_adapter:
|
||||
return self.http_status(404, -1, 'WebChat adapter not found')
|
||||
|
||||
messages = webchat_adapter.get_webchat_messages(pipeline_uuid, session_type)
|
||||
|
||||
return self.success(data={'messages': messages})
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
@self.route('/reset/<session_type>', methods=['POST'])
|
||||
async def reset_session(session_type: str) -> str:
|
||||
"""重置调试会话"""
|
||||
try:
|
||||
if session_type not in ['person', 'group']:
|
||||
return self.http_status(400, -1, 'session_type must be person or group')
|
||||
|
||||
webchat_adapter = None
|
||||
for bot in self.ap.platform_mgr.bots:
|
||||
if hasattr(bot.adapter, '__class__') and bot.adapter.__class__.__name__ == 'WebChatAdapter':
|
||||
webchat_adapter = bot.adapter
|
||||
break
|
||||
|
||||
if not webchat_adapter:
|
||||
return self.http_status(404, -1, 'WebChat adapter not found')
|
||||
|
||||
webchat_adapter.reset_debug_session(session_type)
|
||||
|
||||
return self.success(data={'message': 'Session reset successfully'})
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
@@ -29,3 +29,16 @@ class BotsRouterGroup(group.RouterGroup):
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.bot_service.delete_bot(bot_uuid)
|
||||
return self.success()
|
||||
|
||||
@self.route('/<bot_uuid>/logs', methods=['POST'])
|
||||
async def _(bot_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
from_index = json_data.get('from_index', -1)
|
||||
max_count = json_data.get('max_count', 10)
|
||||
logs, total_count = await self.ap.bot_service.list_event_logs(bot_uuid, from_index, max_count)
|
||||
return self.success(
|
||||
data={
|
||||
'logs': logs,
|
||||
'total_count': total_count,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import base64
|
||||
import quart
|
||||
|
||||
from .....core import taskmgr
|
||||
from .. import group
|
||||
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||
|
||||
|
||||
@group.group_class('plugins', '/api/v1/plugins')
|
||||
@@ -12,35 +13,22 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
plugins = self.ap.plugin_mgr.plugins()
|
||||
plugins = await self.ap.plugin_connector.list_plugins()
|
||||
|
||||
plugins_data = [plugin.model_dump() for plugin in plugins]
|
||||
|
||||
return self.success(data={'plugins': plugins_data})
|
||||
return self.success(data={'plugins': plugins})
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/toggle',
|
||||
methods=['PUT'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
data = await quart.request.json
|
||||
target_enabled = data.get('target_enabled')
|
||||
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
|
||||
return self.success()
|
||||
|
||||
@self.route(
|
||||
'/<author>/<plugin_name>/update',
|
||||
'/<author>/<plugin_name>/upgrade',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
|
||||
self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name=f'plugin-update-{plugin_name}',
|
||||
label=f'更新插件 {plugin_name}',
|
||||
name=f'plugin-upgrade-{plugin_name}',
|
||||
label=f'Upgrading plugin {plugin_name}',
|
||||
context=ctx,
|
||||
)
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
@@ -52,17 +40,17 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
||||
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
|
||||
if plugin is None:
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
return self.success(data={'plugin': plugin.model_dump()})
|
||||
return self.success(data={'plugin': plugin})
|
||||
elif quart.request.method == 'DELETE':
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
|
||||
self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name=f'plugin-remove-{plugin_name}',
|
||||
label=f'删除插件 {plugin_name}',
|
||||
label=f'Removing plugin {plugin_name}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
@@ -74,24 +62,19 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _(author: str, plugin_name: str) -> quart.Response:
|
||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
||||
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
|
||||
if plugin is None:
|
||||
return self.http_status(404, -1, 'plugin not found')
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
return self.success(data={'config': plugin.plugin_config})
|
||||
return self.success(data={'config': plugin['plugin_config']})
|
||||
elif quart.request.method == 'PUT':
|
||||
data = await quart.request.json
|
||||
|
||||
await self.ap.plugin_mgr.set_plugin_config(plugin, data)
|
||||
await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data)
|
||||
|
||||
return self.success(data={})
|
||||
|
||||
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
data = await quart.request.json
|
||||
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
|
||||
return self.success()
|
||||
|
||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
data = await quart.request.json
|
||||
@@ -102,7 +85,47 @@ class PluginsRouterGroup(group.RouterGroup):
|
||||
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-github',
|
||||
label=f'安装插件 ...{short_source_str}',
|
||||
label=f'Installing plugin from github ...{short_source_str}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
data = await quart.request.json
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-marketplace',
|
||||
label=f'Installing plugin from marketplace ...{data}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
return self.success(data={'task_id': wrapper.id})
|
||||
|
||||
@self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
file = (await quart.request.files).get('file')
|
||||
if file is None:
|
||||
return self.http_status(400, -1, 'file is required')
|
||||
|
||||
file_bytes = file.read()
|
||||
|
||||
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
|
||||
|
||||
data = {
|
||||
'plugin_file': file_base64,
|
||||
}
|
||||
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
|
||||
kind='plugin-operation',
|
||||
name='plugin-install-local',
|
||||
label=f'Installing plugin from local ...{file.filename}',
|
||||
context=ctx,
|
||||
)
|
||||
|
||||
|
||||
@@ -36,3 +36,11 @@ class LLMModelsRouterGroup(group.RouterGroup):
|
||||
await self.ap.model_service.delete_llm_model(model_uuid)
|
||||
|
||||
return self.success()
|
||||
|
||||
@self.route('/<model_uuid>/test', methods=['POST'])
|
||||
async def _(model_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
await self.ap.model_service.test_llm_model(model_uuid, json_data)
|
||||
|
||||
return self.success()
|
||||
|
||||
@@ -14,6 +14,11 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
'version': constants.semantic_version,
|
||||
'debug': constants.debug_mode,
|
||||
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
|
||||
'cloud_service_url': (
|
||||
self.ap.instance_config.data['plugin']['cloud_service_url']
|
||||
if 'cloud_service_url' in self.ap.instance_config.data['plugin']
|
||||
else 'https://space.langbot.app'
|
||||
),
|
||||
}
|
||||
)
|
||||
|
||||
@@ -35,16 +40,7 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
|
||||
return self.success(data=task.to_dict())
|
||||
|
||||
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
json_data = await quart.request.json
|
||||
|
||||
scope = json_data.get('scope')
|
||||
|
||||
await self.ap.reload(scope=scope)
|
||||
return self.success()
|
||||
|
||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
@@ -54,3 +50,39 @@ class SystemRouterGroup(group.RouterGroup):
|
||||
ap = self.ap
|
||||
|
||||
return self.success(data=exec(py_code, {'ap': ap}))
|
||||
|
||||
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
data = await quart.request.json
|
||||
|
||||
return self.success(
|
||||
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
|
||||
)
|
||||
|
||||
@self.route(
|
||||
'/debug/plugin/action',
|
||||
methods=['POST'],
|
||||
auth_type=group.AuthType.USER_TOKEN,
|
||||
)
|
||||
async def _() -> str:
|
||||
if not constants.debug_mode:
|
||||
return self.http_status(403, 403, 'Forbidden')
|
||||
|
||||
data = await quart.request.json
|
||||
|
||||
class AnoymousAction:
|
||||
value = 'anonymous_action'
|
||||
|
||||
def __init__(self, value: str):
|
||||
self.value = value
|
||||
|
||||
resp = await self.ap.plugin_connector.handler.call_action(
|
||||
AnoymousAction(data['action']),
|
||||
data['data'],
|
||||
timeout=data.get('timeout', 10),
|
||||
)
|
||||
|
||||
return self.success(data=resp)
|
||||
|
||||
@@ -13,10 +13,12 @@ from . import groups
|
||||
from . import group
|
||||
from .groups import provider as groups_provider
|
||||
from .groups import platform as groups_platform
|
||||
from .groups import pipelines as groups_pipelines
|
||||
|
||||
importutil.import_modules_in_pkg(groups)
|
||||
importutil.import_modules_in_pkg(groups_provider)
|
||||
importutil.import_modules_in_pkg(groups_platform)
|
||||
importutil.import_modules_in_pkg(groups_pipelines)
|
||||
|
||||
|
||||
class HTTPController:
|
||||
@@ -107,4 +109,8 @@ class HTTPController:
|
||||
elif path.endswith('.txt'):
|
||||
mimetype = 'text/plain'
|
||||
|
||||
return await quart.send_from_directory(frontend_path, path, mimetype=mimetype)
|
||||
response = await quart.send_from_directory(frontend_path, path, mimetype=mimetype)
|
||||
response.headers['Cache-Control'] = 'no-cache, no-store, must-revalidate'
|
||||
response.headers['Pragma'] = 'no-cache'
|
||||
response.headers['Expires'] = '0'
|
||||
return response
|
||||
|
||||
@@ -2,6 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
import typing
|
||||
|
||||
from ....core import app
|
||||
from ....entity.persistence import bot as persistence_bot
|
||||
@@ -16,15 +17,19 @@ class BotService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_bots(self) -> list[dict]:
|
||||
async def get_bots(self, include_secret: bool = True) -> list[dict]:
|
||||
"""获取所有机器人"""
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
|
||||
|
||||
bots = result.all()
|
||||
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['adapter_config']
|
||||
|
||||
async def get_bot(self, bot_uuid: str) -> dict | None:
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots]
|
||||
|
||||
async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None:
|
||||
"""获取机器人"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
@@ -35,7 +40,27 @@ class BotService:
|
||||
if bot is None:
|
||||
return None
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['adapter_config']
|
||||
|
||||
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns)
|
||||
|
||||
async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict:
|
||||
"""获取机器人运行时信息"""
|
||||
persistence_bot = await self.get_bot(bot_uuid, include_secret)
|
||||
if persistence_bot is None:
|
||||
raise Exception('Bot not found')
|
||||
|
||||
adapter_runtime_values = {}
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
if runtime_bot is not None:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||
|
||||
return persistence_bot
|
||||
|
||||
async def create_bot(self, bot_data: dict) -> str:
|
||||
"""创建机器人"""
|
||||
@@ -92,9 +117,25 @@ class BotService:
|
||||
if runtime_bot.enable:
|
||||
await runtime_bot.run()
|
||||
|
||||
# update all conversation that use this bot
|
||||
for session in self.ap.sess_mgr.session_list:
|
||||
if session.using_conversation is not None and session.using_conversation.bot_uuid == bot_uuid:
|
||||
session.using_conversation = None
|
||||
|
||||
async def delete_bot(self, bot_uuid: str) -> None:
|
||||
"""删除机器人"""
|
||||
await self.ap.platform_mgr.remove_bot(bot_uuid)
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||
)
|
||||
|
||||
async def list_event_logs(
|
||||
self, bot_uuid: str, from_index: int, max_count: int
|
||||
) -> typing.Tuple[list[dict], int, int, int]:
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
if runtime_bot is None:
|
||||
raise Exception('Bot not found')
|
||||
|
||||
logs, total_count = await runtime_bot.logger.get_logs(from_index, max_count)
|
||||
|
||||
return [log.to_json() for log in logs], total_count
|
||||
|
||||
@@ -6,6 +6,8 @@ import sqlalchemy
|
||||
from ....core import app
|
||||
from ....entity.persistence import model as persistence_model
|
||||
from ....entity.persistence import pipeline as persistence_pipeline
|
||||
from ....provider.modelmgr import requester as model_requester
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
|
||||
|
||||
class ModelsService:
|
||||
@@ -14,11 +16,19 @@ class ModelsService:
|
||||
def __init__(self, ap: app.Application) -> None:
|
||||
self.ap = ap
|
||||
|
||||
async def get_llm_models(self) -> list[dict]:
|
||||
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
|
||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||
|
||||
models = result.all()
|
||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
|
||||
|
||||
masked_columns = []
|
||||
if not include_secret:
|
||||
masked_columns = ['api_keys']
|
||||
|
||||
return [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
|
||||
for model in models
|
||||
]
|
||||
|
||||
async def create_llm_model(self, model_data: dict) -> str:
|
||||
model_data['uuid'] = str(uuid.uuid4())
|
||||
@@ -78,3 +88,26 @@ class ModelsService:
|
||||
)
|
||||
|
||||
await self.ap.model_mgr.remove_llm_model(model_uuid)
|
||||
|
||||
async def test_llm_model(self, model_uuid: str, model_data: dict) -> None:
|
||||
runtime_llm_model: model_requester.RuntimeLLMModel | None = None
|
||||
|
||||
if model_uuid != '_':
|
||||
for model in self.ap.model_mgr.llm_models:
|
||||
if model.model_entity.uuid == model_uuid:
|
||||
runtime_llm_model = model
|
||||
break
|
||||
|
||||
if runtime_llm_model is None:
|
||||
raise Exception('model not found')
|
||||
|
||||
else:
|
||||
runtime_llm_model = await self.ap.model_mgr.init_runtime_llm_model(model_data)
|
||||
|
||||
await runtime_llm_model.requester.invoke_llm(
|
||||
query=None,
|
||||
model=runtime_llm_model,
|
||||
messages=[provider_message.Message(role='user', content='Hello, world!')],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
@@ -112,6 +112,11 @@ class PipelineService:
|
||||
await self.ap.pipeline_mgr.remove_pipeline(pipeline_uuid)
|
||||
await self.ap.pipeline_mgr.load_pipeline(pipeline)
|
||||
|
||||
# update all conversation that use this pipeline
|
||||
for session in self.ap.sess_mgr.session_list:
|
||||
if session.using_conversation is not None and session.using_conversation.pipeline_uuid == pipeline_uuid:
|
||||
session.using_conversation = None
|
||||
|
||||
async def delete_pipeline(self, pipeline_uuid: str) -> None:
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.delete(persistence_pipeline.LegacyPipeline).where(
|
||||
|
||||
@@ -2,9 +2,12 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from ..core import app, entities as core_entities
|
||||
from . import entities, operator, errors
|
||||
from ..core import app
|
||||
from . import operator
|
||||
from ..utils import importutil
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
# 引入所有算子以便注册
|
||||
from . import operators
|
||||
@@ -13,13 +16,11 @@ importutil.import_modules_in_pkg(operators)
|
||||
|
||||
|
||||
class CommandManager:
|
||||
"""命令管理器"""
|
||||
|
||||
ap: app.Application
|
||||
|
||||
cmd_list: list[operator.CommandOperator]
|
||||
"""
|
||||
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
|
||||
Runtime command list, flat storage, each object contains a reference to the corresponding child node
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application):
|
||||
@@ -55,43 +56,28 @@ class CommandManager:
|
||||
|
||||
async def _execute(
|
||||
self,
|
||||
context: entities.ExecuteContext,
|
||||
context: command_context.ExecuteContext,
|
||||
operator_list: list[operator.CommandOperator],
|
||||
operator: operator.CommandOperator = None,
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行命令"""
|
||||
|
||||
found = False
|
||||
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
||||
for oper in operator_list:
|
||||
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
|
||||
oper.parent_class is None or oper.parent_class == operator.__class__
|
||||
):
|
||||
found = True
|
||||
command_list = await self.ap.plugin_connector.list_commands()
|
||||
|
||||
context.crt_command = context.crt_params[0]
|
||||
context.crt_params = context.crt_params[1:]
|
||||
|
||||
async for ret in self._execute(context, oper.children, oper):
|
||||
yield ret
|
||||
break
|
||||
|
||||
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
|
||||
if operator is None:
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
|
||||
else:
|
||||
if operator.lowest_privilege > context.privilege:
|
||||
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
|
||||
else:
|
||||
async for ret in operator.execute(context):
|
||||
yield ret
|
||||
for command in command_list:
|
||||
if command.metadata.name == context.command:
|
||||
async for ret in self.ap.plugin_connector.execute_command(context):
|
||||
yield ret
|
||||
break
|
||||
else:
|
||||
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command))
|
||||
|
||||
async def execute(
|
||||
self,
|
||||
command_text: str,
|
||||
query: core_entities.Query,
|
||||
session: core_entities.Session,
|
||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
query: pipeline_query.Query,
|
||||
session: provider_session.Session,
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行命令"""
|
||||
|
||||
privilege = 1
|
||||
@@ -99,8 +85,8 @@ class CommandManager:
|
||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||
privilege = 2
|
||||
|
||||
ctx = entities.ExecuteContext(
|
||||
query=query,
|
||||
ctx = command_context.ExecuteContext(
|
||||
query_id=query.query_id,
|
||||
session=session,
|
||||
command_text=command_text,
|
||||
command='',
|
||||
@@ -110,5 +96,9 @@ class CommandManager:
|
||||
privilege=privilege,
|
||||
)
|
||||
|
||||
ctx.command = ctx.params[0]
|
||||
|
||||
ctx.shift()
|
||||
|
||||
async for ret in self._execute(ctx, self.cmd_list):
|
||||
yield ret
|
||||
|
||||
@@ -1,74 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ..core import entities as core_entities
|
||||
from . import errors
|
||||
from ..platform.types import message as platform_message
|
||||
|
||||
|
||||
class CommandReturn(pydantic.BaseModel):
|
||||
"""命令返回值"""
|
||||
|
||||
text: typing.Optional[str] = None
|
||||
"""文本
|
||||
"""
|
||||
|
||||
image: typing.Optional[platform_message.Image] = None
|
||||
"""弃用"""
|
||||
|
||||
image_url: typing.Optional[str] = None
|
||||
"""图片链接
|
||||
"""
|
||||
|
||||
error: typing.Optional[errors.CommandError] = None
|
||||
"""错误
|
||||
"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class ExecuteContext(pydantic.BaseModel):
|
||||
"""单次命令执行上下文"""
|
||||
|
||||
query: core_entities.Query
|
||||
"""本次消息的请求对象"""
|
||||
|
||||
session: core_entities.Session
|
||||
"""本次消息所属的会话对象"""
|
||||
|
||||
command_text: str
|
||||
"""命令完整文本"""
|
||||
|
||||
command: str
|
||||
"""命令名称"""
|
||||
|
||||
crt_command: str
|
||||
"""当前命令
|
||||
|
||||
多级命令中crt_command为当前命令,command为根命令。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,command为plugin,crt_command为plugin
|
||||
处理到on时,command为plugin,crt_command为on
|
||||
"""
|
||||
|
||||
params: list[str]
|
||||
"""命令参数
|
||||
|
||||
整个命令以空格分割后的参数列表
|
||||
"""
|
||||
|
||||
crt_params: list[str]
|
||||
"""当前命令参数
|
||||
|
||||
多级命令中crt_params为当前命令参数,params为根命令参数。
|
||||
例如:!plugin on Webwlkr
|
||||
处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
|
||||
处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
|
||||
"""
|
||||
|
||||
privilege: int
|
||||
"""发起人权限"""
|
||||
@@ -1,26 +0,0 @@
|
||||
class CommandError(Exception):
|
||||
def __init__(self, message: str = None):
|
||||
self.message = message
|
||||
|
||||
def __str__(self):
|
||||
return self.message
|
||||
|
||||
|
||||
class CommandNotFoundError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('未知命令: ' + message)
|
||||
|
||||
|
||||
class CommandPrivilegeError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('权限不足: ' + message)
|
||||
|
||||
|
||||
class ParamNotEnoughError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('参数不足: ' + message)
|
||||
|
||||
|
||||
class CommandOperationError(CommandError):
|
||||
def __init__(self, message: str = None):
|
||||
super().__init__('操作失败: ' + message)
|
||||
@@ -4,7 +4,7 @@ import typing
|
||||
import abc
|
||||
|
||||
from ..core import app
|
||||
from . import entities
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
||||
@@ -95,16 +95,18 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""实现此方法以执行命令
|
||||
|
||||
支持多次yield以返回多个结果。
|
||||
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
|
||||
|
||||
Args:
|
||||
context (entities.ExecuteContext): 命令执行上下文
|
||||
context (command_context.ExecuteContext): 命令执行上下文
|
||||
|
||||
Yields:
|
||||
entities.CommandReturn: 命令返回封装
|
||||
command_context.CommandReturn: 命令返回封装
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -2,14 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
|
||||
class CmdOperator(operator.CommandOperator):
|
||||
"""命令列表"""
|
||||
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行"""
|
||||
if len(context.crt_params) == 0:
|
||||
reply_str = '当前所有命令: \n\n'
|
||||
@@ -20,7 +23,7 @@ class CmdOperator(operator.CommandOperator):
|
||||
|
||||
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
yield command_context.CommandReturn(text=reply_str.strip())
|
||||
|
||||
else:
|
||||
cmd_name = context.crt_params[0]
|
||||
@@ -33,9 +36,9 @@ class CmdOperator(operator.CommandOperator):
|
||||
break
|
||||
|
||||
if cmd is None:
|
||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name))
|
||||
else:
|
||||
reply_str = f'{cmd.name}: {cmd.help}\n\n'
|
||||
reply_str += f'使用方法: \n{cmd.usage}'
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
yield command_context.CommandReturn(text=reply_str.strip())
|
||||
|
||||
@@ -2,23 +2,26 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
|
||||
class DelOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
delete_index = 0
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
delete_index = int(context.crt_params[0])
|
||||
except Exception:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数'))
|
||||
return
|
||||
|
||||
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围'))
|
||||
return
|
||||
|
||||
# 倒序
|
||||
@@ -29,15 +32,17 @@ class DelOperator(operator.CommandOperator):
|
||||
|
||||
del context.session.conversations[to_delete_index]
|
||||
|
||||
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||
yield command_context.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
|
||||
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
|
||||
class DelAllOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
context.session.conversations = []
|
||||
context.session.using_conversation = None
|
||||
|
||||
yield entities.CommandReturn(text='已删除所有对话')
|
||||
yield command_context.CommandReturn(text='已删除所有对话')
|
||||
|
||||
@@ -1,19 +1,20 @@
|
||||
from __future__ import annotations
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
|
||||
class FuncOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> AsyncGenerator[command_context.CommandReturn, None]:
|
||||
reply_str = '当前已启用的内容函数: \n\n'
|
||||
|
||||
index = 1
|
||||
|
||||
all_functions = await self.ap.tool_mgr.get_all_functions(
|
||||
plugin_enabled=True,
|
||||
)
|
||||
all_functions = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
for func in all_functions:
|
||||
reply_str += '{}. {}:\n{}\n\n'.format(
|
||||
@@ -23,4 +24,4 @@ class FuncOperator(operator.CommandOperator):
|
||||
)
|
||||
index += 1
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
yield command_context.CommandReturn(text=reply_str)
|
||||
|
||||
@@ -2,14 +2,17 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
|
||||
class HelpOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
||||
|
||||
help += '\n发送命令 !cmd 可查看命令列表'
|
||||
|
||||
yield entities.CommandReturn(text=help)
|
||||
yield command_context.CommandReturn(text=help)
|
||||
|
||||
@@ -3,26 +3,31 @@ from __future__ import annotations
|
||||
import typing
|
||||
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
|
||||
class LastOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的上一个会话
|
||||
for index in range(len(context.session.conversations) - 1, -1, -1):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == 0:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandOperationError('已经是第一个对话了')
|
||||
)
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = context.session.conversations[index - 1]
|
||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
yield entities.CommandReturn(
|
||||
yield command_context.CommandReturn(
|
||||
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
@@ -2,19 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
|
||||
class ListOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
page = 0
|
||||
|
||||
if len(context.crt_params) > 0:
|
||||
try:
|
||||
page = int(context.crt_params[0] - 1)
|
||||
except Exception:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数'))
|
||||
return
|
||||
|
||||
record_per_page = 10
|
||||
@@ -45,4 +48,4 @@ class ListOperator(operator.CommandOperator):
|
||||
else:
|
||||
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
|
||||
|
||||
yield entities.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
||||
yield command_context.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
||||
|
||||
@@ -2,26 +2,31 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
|
||||
class NextOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if context.session.conversations:
|
||||
# 找到当前会话的下一个会话
|
||||
for index in range(len(context.session.conversations)):
|
||||
if context.session.conversations[index] == context.session.using_conversation:
|
||||
if index == len(context.session.conversations) - 1:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandOperationError('已经是最后一个对话了')
|
||||
)
|
||||
return
|
||||
else:
|
||||
context.session.using_conversation = context.session.conversations[index + 1]
|
||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||
|
||||
yield entities.CommandReturn(
|
||||
yield command_context.CommandReturn(
|
||||
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
|
||||
)
|
||||
return
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
|
||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(
|
||||
@@ -11,7 +12,9 @@ from .. import operator, entities, errors
|
||||
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
|
||||
)
|
||||
class PluginOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
plugin_list = self.ap.plugin_mgr.plugins()
|
||||
reply_str = '所有插件({}):\n'.format(len(plugin_list))
|
||||
idx = 0
|
||||
@@ -27,32 +30,36 @@ class PluginOperator(operator.CommandOperator):
|
||||
|
||||
idx += 1
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
yield command_context.CommandReturn(text=reply_str)
|
||||
|
||||
|
||||
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginGetOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||
else:
|
||||
repo = context.crt_params[0]
|
||||
|
||||
yield entities.CommandReturn(text='正在安装插件...')
|
||||
yield command_context.CommandReturn(text='正在安装插件...')
|
||||
|
||||
try:
|
||||
await self.ap.plugin_mgr.install_plugin(repo)
|
||||
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||
yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginUpdateOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
@@ -60,24 +67,26 @@ class PluginUpdateOperator(operator.CommandOperator):
|
||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||
|
||||
if plugin_container is not None:
|
||||
yield entities.CommandReturn(text='正在更新插件...')
|
||||
yield command_context.CommandReturn(text='正在更新插件...')
|
||||
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
||||
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
||||
yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
|
||||
class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
try:
|
||||
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
|
||||
|
||||
if plugins:
|
||||
yield entities.CommandReturn(text='正在更新插件...')
|
||||
yield command_context.CommandReturn(text='正在更新插件...')
|
||||
updated = []
|
||||
try:
|
||||
for plugin_name in plugins:
|
||||
@@ -85,20 +94,22 @@ class PluginUpdateAllOperator(operator.CommandOperator):
|
||||
updated.append(plugin_name)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
||||
else:
|
||||
yield entities.CommandReturn(text='没有可更新的插件')
|
||||
yield command_context.CommandReturn(text='没有可更新的插件')
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginDelOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
@@ -106,51 +117,55 @@ class PluginDelOperator(operator.CommandOperator):
|
||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||
|
||||
if plugin_container is not None:
|
||||
yield entities.CommandReturn(text='正在删除插件...')
|
||||
yield command_context.CommandReturn(text='正在删除插件...')
|
||||
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
||||
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
||||
yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
||||
else:
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件'))
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginEnableOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
||||
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
|
||||
|
||||
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
|
||||
class PluginDisableOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
if len(context.crt_params) == 0:
|
||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
||||
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||
else:
|
||||
plugin_name = context.crt_params[0]
|
||||
|
||||
try:
|
||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
||||
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
||||
else:
|
||||
yield entities.CommandReturn(
|
||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
yield command_context.CommandReturn(
|
||||
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||
)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||
|
||||
@@ -2,19 +2,22 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
|
||||
class PromptOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行"""
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||
else:
|
||||
reply_str = '当前对话所有内容:\n\n'
|
||||
|
||||
for msg in context.session.using_conversation.messages:
|
||||
reply_str += f'{msg.role}: {msg.content}\n'
|
||||
|
||||
yield entities.CommandReturn(text=reply_str)
|
||||
yield command_context.CommandReturn(text=reply_str)
|
||||
|
||||
@@ -2,15 +2,18 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities, errors
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||
|
||||
|
||||
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
|
||||
class ResendOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
# 回滚到最后一条用户message前
|
||||
if context.session.using_conversation is None:
|
||||
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
|
||||
yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话'))
|
||||
else:
|
||||
conv_msg = context.session.using_conversation.messages
|
||||
|
||||
@@ -23,4 +26,4 @@ class ResendOperator(operator.CommandOperator):
|
||||
conv_msg.pop()
|
||||
|
||||
# 不重发了,提示用户已删除就行了
|
||||
yield entities.CommandReturn(text='已删除最后一次请求记录')
|
||||
yield command_context.CommandReturn(text='已删除最后一次请求记录')
|
||||
|
||||
@@ -2,13 +2,16 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
|
||||
class ResetOperator(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
"""执行"""
|
||||
context.session.using_conversation = None
|
||||
|
||||
yield entities.CommandReturn(text='已重置当前会话')
|
||||
yield command_context.CommandReturn(text='已重置当前会话')
|
||||
|
||||
@@ -1,11 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
|
||||
|
||||
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
|
||||
class UpdateCommand(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。')
|
||||
@@ -2,12 +2,15 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from .. import operator, entities
|
||||
from .. import operator
|
||||
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||
|
||||
|
||||
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
|
||||
class VersionCommand(operator.CommandOperator):
|
||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
||||
async def execute(
|
||||
self, context: command_context.ExecuteContext
|
||||
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
|
||||
|
||||
try:
|
||||
@@ -16,4 +19,4 @@ class VersionCommand(operator.CommandOperator):
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
yield entities.CommandReturn(text=reply_str.strip())
|
||||
yield command_context.CommandReturn(text=reply_str.strip())
|
||||
|
||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import asyncio
|
||||
import traceback
|
||||
import sys
|
||||
import os
|
||||
|
||||
from ..platform import botmgr as im_mgr
|
||||
@@ -12,7 +11,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||
from ..config import manager as config_mgr
|
||||
from ..command import cmdmgr
|
||||
from ..plugin import manager as plugin_mgr
|
||||
from ..plugin import connector as plugin_connector
|
||||
from ..pipeline import pool
|
||||
from ..pipeline import controller, pipelinemgr
|
||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||
@@ -23,7 +22,8 @@ from ..api.http.service import model as model_service
|
||||
from ..api.http.service import pipeline as pipeline_service
|
||||
from ..api.http.service import bot as bot_service
|
||||
from ..discover import engine as discover_engine
|
||||
from ..utils import logcache, ip
|
||||
from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
|
||||
@@ -74,7 +74,7 @@ class Application:
|
||||
|
||||
# =========================
|
||||
|
||||
plugin_mgr: plugin_mgr.PluginManager = None
|
||||
plugin_connector: plugin_connector.PluginRuntimeConnector = None
|
||||
|
||||
query_pool: pool.QueryPool = None
|
||||
|
||||
@@ -96,6 +96,8 @@ class Application:
|
||||
|
||||
log_cache: logcache.LogCache = None
|
||||
|
||||
storage_mgr: storagemgr.StorageMgr = None
|
||||
|
||||
# ========= HTTP Services =========
|
||||
|
||||
user_service: user_service.UserService = None
|
||||
@@ -114,7 +116,7 @@ class Application:
|
||||
|
||||
async def run(self):
|
||||
try:
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
await self.plugin_connector.initialize_plugins()
|
||||
|
||||
# 后续可能会允许动态重启其他任务
|
||||
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
||||
@@ -154,89 +156,32 @@ class Application:
|
||||
self.logger.error(f'应用运行致命异常: {e}')
|
||||
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||
|
||||
def dispose(self):
|
||||
self.plugin_connector.dispose()
|
||||
|
||||
async def print_web_access_info(self):
|
||||
"""打印访问 webui 的提示"""
|
||||
|
||||
if not os.path.exists(os.path.join('.', 'web/out')):
|
||||
self.logger.warning('WebUI 文件缺失,请根据文档获取:https://docs.langbot.app/webui/intro.html')
|
||||
self.logger.warning('WebUI 文件缺失,请根据文档部署:https://docs.langbot.app/zh')
|
||||
self.logger.warning(
|
||||
'WebUI files are missing, please deploy according to the documentation: https://docs.langbot.app/en'
|
||||
)
|
||||
return
|
||||
|
||||
host_ip = '127.0.0.1'
|
||||
|
||||
public_ip = await ip.get_myip()
|
||||
|
||||
port = self.instance_config.data['api']['port']
|
||||
|
||||
tips = f"""
|
||||
=======================================
|
||||
✨ 您可通过以下方式访问管理面板
|
||||
✨ Access WebUI / 访问管理面板
|
||||
|
||||
🏠 本地地址:http://{host_ip}:{port}/
|
||||
🌐 公网地址:http://{public_ip}:{port}/
|
||||
🏠 Local Address: http://{host_ip}:{port}/
|
||||
🌐 Public Address: http://<Your Public IP>:{port}/
|
||||
|
||||
📌 如果您在容器中运行此程序,请确保容器的 {port} 端口已对外暴露
|
||||
🔗 若要使用公网地址访问,请阅读以下须知
|
||||
1. 公网地址仅供参考,请以您的主机公网 IP 为准;
|
||||
2. 要使用公网地址访问,请确保您的主机具有公网 IP,并且系统防火墙已放行 {port} 端口;
|
||||
|
||||
🤯 WebUI 仍处于 Beta 测试阶段,如有问题或建议请反馈到 https://github.com/RockChinQ/LangBot/issues
|
||||
📌 Running this program in a container? Please ensure that the {port} port is exposed
|
||||
=======================================
|
||||
""".strip()
|
||||
for line in tips.split('\n'):
|
||||
self.logger.info(line)
|
||||
|
||||
async def reload(
|
||||
self,
|
||||
scope: core_entities.LifecycleControlScope,
|
||||
):
|
||||
match scope:
|
||||
case core_entities.LifecycleControlScope.PLATFORM.value:
|
||||
self.logger.info('执行热重载 scope=' + scope)
|
||||
await self.platform_mgr.shutdown()
|
||||
|
||||
self.platform_mgr = im_mgr.PlatformManager(self)
|
||||
|
||||
await self.platform_mgr.initialize()
|
||||
|
||||
self.task_mgr.create_task(
|
||||
self.platform_mgr.run(),
|
||||
name='platform-manager',
|
||||
scopes=[
|
||||
core_entities.LifecycleControlScope.APPLICATION,
|
||||
core_entities.LifecycleControlScope.PLATFORM,
|
||||
],
|
||||
)
|
||||
case core_entities.LifecycleControlScope.PLUGIN.value:
|
||||
self.logger.info('执行热重载 scope=' + scope)
|
||||
await self.plugin_mgr.destroy_plugins()
|
||||
|
||||
# 删除 sys.module 中所有的 plugins/* 下的模块
|
||||
for mod in list(sys.modules.keys()):
|
||||
if mod.startswith('plugins.'):
|
||||
del sys.modules[mod]
|
||||
|
||||
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
||||
await self.plugin_mgr.initialize()
|
||||
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
|
||||
await self.plugin_mgr.load_plugins()
|
||||
await self.plugin_mgr.initialize_plugins()
|
||||
case core_entities.LifecycleControlScope.PROVIDER.value:
|
||||
self.logger.info('执行热重载 scope=' + scope)
|
||||
|
||||
await self.tool_mgr.shutdown()
|
||||
|
||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
|
||||
await llm_model_mgr_inst.initialize()
|
||||
self.model_mgr = llm_model_mgr_inst
|
||||
|
||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
|
||||
await llm_session_mgr_inst.initialize()
|
||||
self.sess_mgr = llm_session_mgr_inst
|
||||
|
||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
self.tool_mgr = llm_tool_mgr_inst
|
||||
case _:
|
||||
pass
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from __future__ import print_function
|
||||
from __future__ import annotations
|
||||
|
||||
import traceback
|
||||
import asyncio
|
||||
@@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop):
|
||||
import signal
|
||||
|
||||
def signal_handler(sig, frame):
|
||||
app_inst.dispose()
|
||||
print('[Signal] 程序退出.')
|
||||
# ap.shutdown()
|
||||
os._exit(0)
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
|
||||
@@ -74,5 +74,5 @@ async def precheck_plugin_deps():
|
||||
if 'requirements.txt' in os.listdir(subdir):
|
||||
pkgmgr.install_requirements(
|
||||
os.path.join(subdir, 'requirements.txt'),
|
||||
extra_params=['-q', '-q', '-q'],
|
||||
extra_params=[],
|
||||
)
|
||||
|
||||
@@ -1,18 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import enum
|
||||
import typing
|
||||
import datetime
|
||||
import asyncio
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
|
||||
from ..provider import entities as llm_entities
|
||||
from ..provider.modelmgr import requester
|
||||
from ..provider.tools import entities as tools_entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
from ..platform.types import events as platform_events
|
||||
|
||||
|
||||
class LifecycleControlScope(enum.Enum):
|
||||
@@ -20,151 +8,3 @@ class LifecycleControlScope(enum.Enum):
|
||||
PLATFORM = 'platform'
|
||||
PLUGIN = 'plugin'
|
||||
PROVIDER = 'provider'
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
"""一个请求的发起者类型"""
|
||||
|
||||
PERSON = 'person'
|
||||
"""私聊"""
|
||||
|
||||
GROUP = 'group'
|
||||
"""群聊"""
|
||||
|
||||
|
||||
class Query(pydantic.BaseModel):
|
||||
"""一次请求的信息封装"""
|
||||
|
||||
query_id: int
|
||||
"""请求ID,添加进请求池时生成"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
"""会话类型,platform处理阶段设置"""
|
||||
|
||||
launcher_id: typing.Union[int, str]
|
||||
"""会话ID,platform处理阶段设置"""
|
||||
|
||||
sender_id: typing.Union[int, str]
|
||||
"""发送者ID,platform处理阶段设置"""
|
||||
|
||||
message_event: platform_events.MessageEvent
|
||||
"""事件,platform收到的原始事件"""
|
||||
|
||||
message_chain: platform_message.MessageChain
|
||||
"""消息链,platform收到的原始消息链"""
|
||||
|
||||
bot_uuid: typing.Optional[str] = None
|
||||
"""机器人UUID。"""
|
||||
|
||||
pipeline_uuid: typing.Optional[str] = None
|
||||
"""流水线UUID。"""
|
||||
|
||||
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""流水线配置,由 Pipeline 在运行开始时设置。"""
|
||||
|
||||
adapter: msadapter.MessagePlatformAdapter
|
||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
||||
|
||||
session: typing.Optional[Session] = None
|
||||
"""会话对象,由前置处理器阶段设置"""
|
||||
|
||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
||||
"""历史消息列表,由前置处理器阶段设置"""
|
||||
|
||||
prompt: typing.Optional[llm_entities.Prompt] = None
|
||||
"""情景预设内容,由前置处理器阶段设置"""
|
||||
|
||||
user_message: typing.Optional[llm_entities.Message] = None
|
||||
"""此次请求的用户消息对象,由前置处理器阶段设置"""
|
||||
|
||||
variables: typing.Optional[dict[str, typing.Any]] = None
|
||||
"""变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
|
||||
|
||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
||||
"""使用的对话模型,由前置处理器阶段设置"""
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: (
|
||||
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
|
||||
) = []
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
|
||||
"""回复消息链,从resp_messages包装而得"""
|
||||
|
||||
# ======= 内部保留 =======
|
||||
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
|
||||
"""当前所处阶段"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
# ========== 插件可调用的 API(请求 API) ==========
|
||||
|
||||
def set_variable(self, key: str, value: typing.Any):
|
||||
"""设置变量"""
|
||||
if self.variables is None:
|
||||
self.variables = {}
|
||||
self.variables[key] = value
|
||||
|
||||
def get_variable(self, key: str) -> typing.Any:
|
||||
"""获取变量"""
|
||||
if self.variables is None:
|
||||
return None
|
||||
return self.variables.get(key)
|
||||
|
||||
def get_variables(self) -> dict[str, typing.Any]:
|
||||
"""获取所有变量"""
|
||||
if self.variables is None:
|
||||
return {}
|
||||
return self.variables
|
||||
|
||||
|
||||
class Conversation(pydantic.BaseModel):
|
||||
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""
|
||||
|
||||
prompt: llm_entities.Prompt
|
||||
|
||||
messages: list[llm_entities.Message]
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
||||
|
||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
|
||||
|
||||
uuid: typing.Optional[str] = None
|
||||
"""该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class Session(pydantic.BaseModel):
|
||||
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
||||
|
||||
launcher_type: LauncherTypes
|
||||
|
||||
launcher_id: typing.Union[int, str]
|
||||
|
||||
sender_id: typing.Optional[typing.Union[int, str]] = 0
|
||||
|
||||
use_prompt_name: typing.Optional[str] = 'default'
|
||||
|
||||
using_conversation: typing.Optional[Conversation] = None
|
||||
|
||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
|
||||
|
||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
||||
|
||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
||||
"""当前会话的信号量,用于限制并发"""
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from .. import stage, app
|
||||
from ...utils import version, proxy, announce
|
||||
from ...pipeline import pool, controller, pipelinemgr
|
||||
from ...plugin import manager as plugin_mgr
|
||||
from ...plugin import connector as plugin_connector
|
||||
from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
@@ -17,6 +18,7 @@ from ...api.http.service import model as model_service
|
||||
from ...api.http.service import pipeline as pipeline_service
|
||||
from ...api.http.service import bot as bot_service
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
from .. import taskmgr
|
||||
|
||||
@@ -50,14 +52,21 @@ class BuildAppStage(stage.BootingStage):
|
||||
log_cache = logcache.LogCache()
|
||||
ap.log_cache = log_cache
|
||||
|
||||
storage_mgr_inst = storagemgr.StorageMgr(ap)
|
||||
await storage_mgr_inst.initialize()
|
||||
ap.storage_mgr = storage_mgr_inst
|
||||
|
||||
persistence_mgr_inst = persistencemgr.PersistenceManager(ap)
|
||||
ap.persistence_mgr = persistence_mgr_inst
|
||||
await persistence_mgr_inst.initialize()
|
||||
|
||||
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
|
||||
await plugin_mgr_inst.initialize()
|
||||
ap.plugin_mgr = plugin_mgr_inst
|
||||
await plugin_mgr_inst.load_plugins()
|
||||
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
|
||||
await asyncio.sleep(3)
|
||||
await plugin_connector_inst.initialize()
|
||||
|
||||
plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback)
|
||||
await plugin_connector_inst.initialize()
|
||||
ap.plugin_connector = plugin_connector_inst
|
||||
|
||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||
await cmd_mgr_inst.initialize()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
from .. import stage, app, note
|
||||
from ...utils import importutil
|
||||
|
||||
@@ -20,11 +22,15 @@ class ShowNotesStage(stage.BootingStage):
|
||||
try:
|
||||
note_inst = note_cls(ap)
|
||||
if await note_inst.need_show():
|
||||
async for ret in note_inst.yield_note():
|
||||
if not ret:
|
||||
continue
|
||||
msg, level = ret
|
||||
if msg:
|
||||
ap.logger.log(level, msg)
|
||||
|
||||
async def ayield_note(note_inst: note.LaunchNote):
|
||||
async for ret in note_inst.yield_note():
|
||||
if not ret:
|
||||
continue
|
||||
msg, level = ret
|
||||
if msg:
|
||||
ap.logger.log(level, msg)
|
||||
|
||||
asyncio.create_task(ayield_note(note_inst))
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
9
pkg/entity/errors/platform.py
Normal file
9
pkg/entity/errors/platform.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class AdapterNotFoundError(Exception):
|
||||
def __init__(self, adapter_name: str):
|
||||
self.adapter_name = adapter_name
|
||||
|
||||
def __str__(self):
|
||||
return f'Adapter {self.adapter_name} not found'
|
||||
9
pkg/entity/errors/provider.py
Normal file
9
pkg/entity/errors/provider.py
Normal file
@@ -0,0 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class RequesterNotFoundError(Exception):
|
||||
def __init__(self, requester_name: str):
|
||||
self.requester_name = requester_name
|
||||
|
||||
def __str__(self):
|
||||
return f'Requester {self.requester_name} not found'
|
||||
22
pkg/entity/persistence/bstorage.py
Normal file
22
pkg/entity/persistence/bstorage.py
Normal file
@@ -0,0 +1,22 @@
|
||||
import sqlalchemy
|
||||
|
||||
from .base import Base
|
||||
|
||||
|
||||
class BinaryStorage(Base):
|
||||
"""Current for plugin use only"""
|
||||
|
||||
__tablename__ = 'binary_storages'
|
||||
|
||||
unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||
key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||
value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
nullable=False,
|
||||
server_default=sqlalchemy.func.now(),
|
||||
onupdate=sqlalchemy.func.now(),
|
||||
)
|
||||
@@ -13,6 +13,8 @@ class PluginSetting(Base):
|
||||
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||
install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github')
|
||||
install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||
updated_at = sqlalchemy.Column(
|
||||
sqlalchemy.DateTime,
|
||||
|
||||
@@ -44,46 +44,6 @@ class PersistenceManager:
|
||||
|
||||
await self.create_tables()
|
||||
|
||||
async def create_tables(self):
|
||||
# create tables
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
await conn.run_sync(self.meta.create_all)
|
||||
|
||||
await conn.commit()
|
||||
|
||||
# ======= write initial data =======
|
||||
|
||||
# write initial metadata
|
||||
self.ap.logger.info('Creating initial metadata...')
|
||||
for item in metadata.initial_metadata:
|
||||
# check if the item exists
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
|
||||
|
||||
# write default pipeline
|
||||
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
|
||||
if result.first() is None:
|
||||
self.ap.logger.info('Creating default pipeline...')
|
||||
|
||||
pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
|
||||
|
||||
pipeline_data = {
|
||||
'uuid': str(uuid.uuid4()),
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
'stages': pipeline_service.default_stage_order,
|
||||
'is_default': True,
|
||||
'name': 'ChatPipeline',
|
||||
'description': '默认提供的流水线,您配置的机器人、第一个模型将自动绑定到此流水线',
|
||||
'config': pipeline_config,
|
||||
}
|
||||
|
||||
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
|
||||
# =================================
|
||||
|
||||
# run migrations
|
||||
database_version = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||
@@ -116,6 +76,49 @@ class PersistenceManager:
|
||||
|
||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||
|
||||
async def create_tables(self):
|
||||
# create tables
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
await conn.run_sync(self.meta.create_all)
|
||||
|
||||
await conn.commit()
|
||||
|
||||
# ======= write initial data =======
|
||||
|
||||
# write initial metadata
|
||||
self.ap.logger.info('Creating initial metadata...')
|
||||
for item in metadata.initial_metadata:
|
||||
# check if the item exists
|
||||
result = await self.execute_async(
|
||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == item['key'])
|
||||
)
|
||||
row = result.first()
|
||||
if row is None:
|
||||
await self.execute_async(sqlalchemy.insert(metadata.Metadata).values(item))
|
||||
|
||||
# write default pipeline
|
||||
result = await self.execute_async(sqlalchemy.select(pipeline.LegacyPipeline))
|
||||
default_pipeline_uuid = None
|
||||
if result.first() is None:
|
||||
self.ap.logger.info('Creating default pipeline...')
|
||||
|
||||
pipeline_config = json.load(open('templates/default-pipeline-config.json', 'r', encoding='utf-8'))
|
||||
|
||||
default_pipeline_uuid = str(uuid.uuid4())
|
||||
pipeline_data = {
|
||||
'uuid': default_pipeline_uuid,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
'stages': pipeline_service.default_stage_order,
|
||||
'is_default': True,
|
||||
'name': 'ChatPipeline',
|
||||
'description': '默认提供的流水线,您配置的机器人、第一个模型将自动绑定到此流水线',
|
||||
'config': pipeline_config,
|
||||
}
|
||||
|
||||
await self.execute_async(sqlalchemy.insert(pipeline.LegacyPipeline).values(pipeline_data))
|
||||
|
||||
# =================================
|
||||
|
||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||
async with self.get_db_engine().connect() as conn:
|
||||
result = await conn.execute(*args, **kwargs)
|
||||
@@ -125,10 +128,13 @@ class PersistenceManager:
|
||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||
return self.db.get_engine()
|
||||
|
||||
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
|
||||
def serialize_model(
|
||||
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = []
|
||||
) -> dict:
|
||||
return {
|
||||
column.name: getattr(data, column.name)
|
||||
if not isinstance(getattr(data, column.name), (datetime.datetime))
|
||||
else getattr(data, column.name).isoformat()
|
||||
for column in model.__table__.columns
|
||||
if column.name not in masked_columns
|
||||
}
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
|
||||
|
||||
@migration.migration_class(2)
|
||||
class DBMigrateCombineQuoteMsgConfig(migration.DBMigration):
|
||||
"""引用消息合并配置"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
|
||||
if 'misc' not in config['trigger']:
|
||||
config['trigger']['misc'] = {}
|
||||
|
||||
if 'combine-quote-message' not in config['trigger']['misc']:
|
||||
config['trigger']['misc']['combine-quote-message'] = False
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
49
pkg/persistence/migrations/dbm003_n8n_config.py
Normal file
49
pkg/persistence/migrations/dbm003_n8n_config.py
Normal file
@@ -0,0 +1,49 @@
|
||||
from .. import migration
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ...entity.persistence import pipeline as persistence_pipeline
|
||||
|
||||
|
||||
@migration.migration_class(3)
|
||||
class DBMigrateN8nConfig(migration.DBMigration):
|
||||
"""N8n配置"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
# read all pipelines
|
||||
pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
|
||||
|
||||
for pipeline in pipelines:
|
||||
serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
|
||||
|
||||
config = serialized_pipeline['config']
|
||||
|
||||
if 'n8n-service-api' not in config['ai']:
|
||||
config['ai']['n8n-service-api'] = {
|
||||
'webhook-url': 'http://your-n8n-webhook-url',
|
||||
'auth-type': 'none',
|
||||
'basic-username': '',
|
||||
'basic-password': '',
|
||||
'jwt-secret': '',
|
||||
'jwt-algorithm': 'HS256',
|
||||
'header-name': '',
|
||||
'header-value': '',
|
||||
'timeout': 120,
|
||||
'output-key': 'response',
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_pipeline.LegacyPipeline)
|
||||
.where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
|
||||
.values(
|
||||
{
|
||||
'config': config,
|
||||
'for_version': self.ap.ver_mgr.get_current_version(),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
20
pkg/persistence/migrations/dbm004_plugin_config.py
Normal file
20
pkg/persistence/migrations/dbm004_plugin_config.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(4)
|
||||
class DBMigratePluginConfig(migration.DBMigration):
|
||||
"""插件配置"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
|
||||
if 'plugin' not in self.ap.instance_config.data:
|
||||
self.ap.instance_config.data['plugin'] = {
|
||||
'runtime_ws_url': 'ws://localhost:5400/control/ws',
|
||||
}
|
||||
|
||||
await self.ap.instance_config.dump_config()
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
25
pkg/persistence/migrations/dbm005_plugin_install_source.py
Normal file
25
pkg/persistence/migrations/dbm005_plugin_install_source.py
Normal file
@@ -0,0 +1,25 @@
|
||||
import sqlalchemy
|
||||
from .. import migration
|
||||
|
||||
|
||||
@migration.migration_class(5)
|
||||
class DBMigratePluginInstallSource(migration.DBMigration):
|
||||
"""插件安装来源"""
|
||||
|
||||
async def upgrade(self):
|
||||
"""升级"""
|
||||
# add new column install_source, use default value 'github', via alter table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text(
|
||||
"ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'"
|
||||
)
|
||||
)
|
||||
|
||||
# add new column install_info, use default value {}, via alter table
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'")
|
||||
)
|
||||
|
||||
async def downgrade(self):
|
||||
"""降级"""
|
||||
pass
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@stage.stage_class('BanSessionCheckStage')
|
||||
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
||||
async def initialize(self, pipeline_config: dict):
|
||||
pass
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
found = False
|
||||
|
||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||
|
||||
@@ -3,12 +3,11 @@ from __future__ import annotations
|
||||
from ...core import app
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import filter as filter_model, entities as filter_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...platform.types import message as platform_message
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import filters
|
||||
|
||||
importutil.import_modules_in_pkg(filters)
|
||||
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
async def _pre_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm前处理消息
|
||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||
@@ -66,6 +65,8 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
if not message.strip():
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
@@ -84,14 +85,14 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||
message = result.replacement
|
||||
|
||||
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
|
||||
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def _post_process(
|
||||
self,
|
||||
message: str,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
"""请求llm后处理响应
|
||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||
@@ -121,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
if stage_inst_name == 'PreContentFilterStage':
|
||||
contain_non_text = False
|
||||
@@ -140,7 +141,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||
elif stage_inst_name == 'PostContentFilterStage':
|
||||
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
|
||||
if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
|
||||
query.resp_messages[-1].content, str
|
||||
):
|
||||
return await self._post_process(query.resp_messages[-1].content, query)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import enum
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
import pydantic
|
||||
|
||||
|
||||
class ResultLevel(enum.Enum):
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
import abc
|
||||
import typing
|
||||
|
||||
from ...core import app, entities as core_entities
|
||||
from ...core import app
|
||||
from . import entities
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||
|
||||
@@ -60,7 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||
"""处理消息
|
||||
|
||||
分为前后阶段,具体取决于 enable_stages 的值。
|
||||
|
||||
@@ -4,8 +4,7 @@ import aiohttp
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
||||
) as resp:
|
||||
return (await resp.json())['access_token']
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from .. import filter as filter_model
|
||||
from .. import entities
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@filter_model.filter_class('ban-word-filter')
|
||||
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
||||
async def initialize(self):
|
||||
pass
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
found = False
|
||||
|
||||
for word in self.ap.sensitive_meta.data['words']:
|
||||
|
||||
@@ -3,7 +3,7 @@ import re
|
||||
|
||||
from .. import entities
|
||||
from .. import filter as filter_model
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@filter_model.filter_class('content-ignore')
|
||||
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
|
||||
entities.EnableStage.PRE,
|
||||
]
|
||||
|
||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
||||
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||
if message.startswith(rule):
|
||||
|
||||
@@ -3,7 +3,10 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import traceback
|
||||
|
||||
from ..core import app, entities
|
||||
from ..core import app
|
||||
from ..core import entities as core_entities
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
class Controller:
|
||||
@@ -22,19 +25,19 @@ class Controller:
|
||||
"""事件处理循环"""
|
||||
try:
|
||||
while True:
|
||||
selected_query: entities.Query = None
|
||||
selected_query: pipeline_query.Query = None
|
||||
|
||||
# 取请求
|
||||
async with self.ap.query_pool:
|
||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
||||
queries: list[pipeline_query.Query] = self.ap.query_pool.queries
|
||||
|
||||
for query in queries:
|
||||
session = await self.ap.sess_mgr.get_session(query)
|
||||
self.ap.logger.debug(f'Checking query {query} session {session}')
|
||||
|
||||
if not session.semaphore.locked():
|
||||
if not session._semaphore.locked():
|
||||
selected_query = query
|
||||
await session.semaphore.acquire()
|
||||
await session._semaphore.acquire()
|
||||
|
||||
break
|
||||
|
||||
@@ -46,21 +49,20 @@ class Controller:
|
||||
|
||||
if selected_query:
|
||||
|
||||
async def _process_query(selected_query: entities.Query):
|
||||
async def _process_query(selected_query: pipeline_query.Query):
|
||||
async with self.semaphore: # 总并发上限
|
||||
# find pipeline
|
||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||
# Like aiocqhttp, once a client is connected, even the adapter was updated and restarted, the existing client connection will not be affected.
|
||||
bot = await self.ap.platform_mgr.get_bot_by_uuid(selected_query.bot_uuid)
|
||||
if bot:
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(
|
||||
bot.bot_entity.use_pipeline_uuid
|
||||
)
|
||||
pipeline_uuid = selected_query.pipeline_uuid
|
||||
|
||||
if pipeline_uuid:
|
||||
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
|
||||
if pipeline:
|
||||
await pipeline.run(selected_query)
|
||||
|
||||
async with self.ap.query_pool:
|
||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
||||
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
|
||||
# 通知其他协程,有新的请求可以处理了
|
||||
self.ap.query_pool.condition.notify_all()
|
||||
|
||||
@@ -69,8 +71,8 @@ class Controller:
|
||||
kind='query',
|
||||
name=f'query-{selected_query.query_id}',
|
||||
scopes=[
|
||||
entities.LifecycleControlScope.APPLICATION,
|
||||
entities.LifecycleControlScope.PLATFORM,
|
||||
core_entities.LifecycleControlScope.APPLICATION,
|
||||
core_entities.LifecycleControlScope.PLATFORM,
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
import enum
|
||||
import typing
|
||||
|
||||
import pydantic.v1 as pydantic
|
||||
from ..platform.types import message as platform_message
|
||||
import pydantic
|
||||
|
||||
from ..core import entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
class ResultType(enum.Enum):
|
||||
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
|
||||
class StageProcessResult(pydantic.BaseModel):
|
||||
result_type: ResultType
|
||||
|
||||
new_query: entities.Query
|
||||
new_query: pipeline_query.Query
|
||||
|
||||
user_notice: typing.Optional[
|
||||
typing.Union[
|
||||
|
||||
@@ -5,10 +5,9 @@ import traceback
|
||||
|
||||
from . import strategy
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import strategies
|
||||
|
||||
importutil.import_modules_in_pkg(strategies)
|
||||
@@ -67,7 +66,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
||||
|
||||
await self.strategy_impl.initialize()
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
# 检查是否包含非 Plain 组件
|
||||
contains_non_plain = False
|
||||
|
||||
|
||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
||||
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
||||
Forward = platform_message.Forward
|
||||
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
|
||||
|
||||
@strategy_model.strategy_class('forward')
|
||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||
display = ForwardMessageDiaplay(
|
||||
title='群聊的聊天记录',
|
||||
brief='[聊天记录]',
|
||||
|
||||
@@ -8,10 +8,10 @@ import re
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
|
||||
import functools
|
||||
from ....platform.types import message as platform_message
|
||||
|
||||
from .. import strategy as strategy_model
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
|
||||
|
||||
@strategy_model.strategy_class('image')
|
||||
@@ -20,14 +20,14 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
pass
|
||||
|
||||
@functools.lru_cache(maxsize=16)
|
||||
def get_font(self, query: core_entities.Query):
|
||||
def get_font(self, font_path: str):
|
||||
return ImageFont.truetype(
|
||||
query.pipeline_config['output']['long-text-processing']['font-path'],
|
||||
font_path,
|
||||
32,
|
||||
encoding='utf-8',
|
||||
)
|
||||
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||
img_path = self.text_to_image(
|
||||
text_str=message,
|
||||
save_as='temp/{}.png'.format(int(time.time())),
|
||||
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
text_str: str,
|
||||
save_as='temp.png',
|
||||
width=800,
|
||||
query: core_entities.Query = None,
|
||||
query: pipeline_query.Query = None,
|
||||
):
|
||||
text_str = text_str.replace('\t', ' ')
|
||||
|
||||
@@ -146,7 +146,9 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
self.ap.logger.debug('lines: {}, text_width: {}'.format(lines, text_width))
|
||||
for line in lines:
|
||||
# 如果长了就分割
|
||||
line_width = self.get_font(query).getlength(line)
|
||||
line_width = self.get_font(query.pipeline_config['output']['long-text-processing']['font-path']).getlength(
|
||||
line
|
||||
)
|
||||
self.ap.logger.debug('line_width: {}'.format(line_width))
|
||||
if line_width < text_width:
|
||||
final_lines.append(line)
|
||||
@@ -167,7 +169,9 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
|
||||
final_lines.append(rest_text[:point])
|
||||
rest_text = rest_text[point:]
|
||||
line_width = self.text_render_font.getlength(rest_text)
|
||||
line_width = self.get_font(
|
||||
query.pipeline_config['output']['long-text-processing']['font-path']
|
||||
).getlength(rest_text)
|
||||
if line_width < text_width:
|
||||
final_lines.append(rest_text)
|
||||
break
|
||||
@@ -187,7 +191,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
||||
(offset_x, offset_y + 35 * line_number),
|
||||
final_line,
|
||||
fill=(0, 0, 0),
|
||||
font=self.text_render_font,
|
||||
font=self.get_font(query.pipeline_config['output']['long-text-processing']['font-path']),
|
||||
)
|
||||
# 遍历此行,检查是否有emoji
|
||||
idx_in_line = 0
|
||||
|
||||
@@ -4,8 +4,9 @@ import typing
|
||||
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||
@@ -49,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
||||
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||
"""处理长文本
|
||||
|
||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from . import truncator
|
||||
from ...utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
from . import truncators
|
||||
|
||||
importutil.import_modules_in_pkg(truncators)
|
||||
@@ -29,7 +28,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
||||
else:
|
||||
raise ValueError(f'未知的截断器: {use_method}')
|
||||
|
||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
query = await self.trun.truncate(query)
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import typing
|
||||
import abc
|
||||
|
||||
from ...core import entities as core_entities, app
|
||||
|
||||
from ...core import app
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||
|
||||
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
|
||||
pass
|
||||
|
||||
@abc.abstractmethod
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||
"""截断
|
||||
|
||||
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from .. import truncator
|
||||
from ....core import entities as core_entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@truncator.truncator_class('round')
|
||||
class RoundTruncator(truncator.Truncator):
|
||||
"""前文回合数阶段器"""
|
||||
|
||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
||||
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||
"""截断"""
|
||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||
|
||||
|
||||
@@ -5,14 +5,18 @@ import traceback
|
||||
|
||||
import sqlalchemy
|
||||
|
||||
from ..core import app, entities
|
||||
from ..core import app
|
||||
from . import entities as pipeline_entities
|
||||
from ..entity.persistence import pipeline as persistence_pipeline
|
||||
from . import stage
|
||||
from ..platform.types import message as platform_message, events as platform_events
|
||||
from ..plugin import events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ..utils import importutil
|
||||
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
from . import (
|
||||
resprule,
|
||||
bansess,
|
||||
@@ -75,17 +79,17 @@ class RuntimePipeline:
|
||||
self.pipeline_entity = pipeline_entity
|
||||
self.stage_containers = stage_containers
|
||||
|
||||
async def run(self, query: entities.Query):
|
||||
async def run(self, query: pipeline_query.Query):
|
||||
query.pipeline_config = self.pipeline_entity.config
|
||||
await self.process_query(query)
|
||||
|
||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
||||
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
|
||||
"""检查输出"""
|
||||
if result.user_notice:
|
||||
# 处理str类型
|
||||
|
||||
if isinstance(result.user_notice, str):
|
||||
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
|
||||
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
|
||||
elif isinstance(result.user_notice, list):
|
||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||
|
||||
@@ -109,7 +113,7 @@ class RuntimePipeline:
|
||||
async def _execute_from_stage(
|
||||
self,
|
||||
stage_index: int,
|
||||
query: entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
):
|
||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||
|
||||
@@ -136,7 +140,7 @@ class RuntimePipeline:
|
||||
while i < len(self.stage_containers):
|
||||
stage_container = self.stage_containers[i]
|
||||
|
||||
query.current_stage = stage_container # 标记到 Query 对象里
|
||||
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
|
||||
|
||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||
|
||||
@@ -169,26 +173,26 @@ class RuntimePipeline:
|
||||
|
||||
i += 1
|
||||
|
||||
async def process_query(self, query: entities.Query):
|
||||
async def process_query(self, query: pipeline_query.Query):
|
||||
"""处理请求"""
|
||||
try:
|
||||
# ======== 触发 MessageReceived 事件 ========
|
||||
event_type = (
|
||||
events.PersonMessageReceived
|
||||
if query.launcher_type == entities.LauncherTypes.PERSON
|
||||
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||
else events.GroupMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_type(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
query=query,
|
||||
)
|
||||
event_obj = event_type(
|
||||
query=query,
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
message_chain=query.message_chain,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
return
|
||||
|
||||
@@ -196,11 +200,12 @@ class RuntimePipeline:
|
||||
|
||||
await self._execute_from_stage(0, query)
|
||||
except Exception as e:
|
||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
||||
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||
finally:
|
||||
self.ap.logger.debug(f'Query {query} processed')
|
||||
del self.ap.query_pool.cached_queries[query.query_id]
|
||||
|
||||
|
||||
class PipelineManager:
|
||||
|
||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import typing
|
||||
|
||||
from ..core import entities
|
||||
from ..platform import adapter as msadapter
|
||||
from ..platform.types import message as platform_message
|
||||
from ..platform.types import events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||
|
||||
|
||||
class QueryPool:
|
||||
@@ -16,7 +17,10 @@ class QueryPool:
|
||||
|
||||
pool_lock: asyncio.Lock
|
||||
|
||||
queries: list[entities.Query]
|
||||
queries: list[pipeline_query.Query]
|
||||
|
||||
cached_queries: dict[int, pipeline_query.Query]
|
||||
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
|
||||
|
||||
condition: asyncio.Condition
|
||||
|
||||
@@ -24,32 +28,38 @@ class QueryPool:
|
||||
self.query_id_counter = 0
|
||||
self.pool_lock = asyncio.Lock()
|
||||
self.queries = []
|
||||
self.cached_queries = {}
|
||||
self.condition = asyncio.Condition(self.pool_lock)
|
||||
|
||||
async def add_query(
|
||||
self,
|
||||
bot_uuid: str,
|
||||
launcher_type: entities.LauncherTypes,
|
||||
launcher_type: provider_session.LauncherTypes,
|
||||
launcher_id: typing.Union[int, str],
|
||||
sender_id: typing.Union[int, str],
|
||||
message_event: platform_events.MessageEvent,
|
||||
message_chain: platform_message.MessageChain,
|
||||
adapter: msadapter.MessagePlatformAdapter,
|
||||
) -> entities.Query:
|
||||
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||
pipeline_uuid: typing.Optional[str] = None,
|
||||
) -> pipeline_query.Query:
|
||||
async with self.condition:
|
||||
query = entities.Query(
|
||||
query_id = self.query_id_counter
|
||||
query = pipeline_query.Query(
|
||||
bot_uuid=bot_uuid,
|
||||
query_id=self.query_id_counter,
|
||||
query_id=query_id,
|
||||
launcher_type=launcher_type,
|
||||
launcher_id=launcher_id,
|
||||
sender_id=sender_id,
|
||||
message_event=message_event,
|
||||
message_chain=message_chain,
|
||||
variables={},
|
||||
resp_messages=[],
|
||||
resp_message_chain=[],
|
||||
adapter=adapter,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
)
|
||||
self.queries.append(query)
|
||||
self.cached_queries[query_id] = query
|
||||
self.query_id_counter += 1
|
||||
self.condition.notify_all()
|
||||
|
||||
|
||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
||||
import datetime
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
from ...provider import entities as llm_entities
|
||||
from ...plugin import events
|
||||
from ...platform.types import message as platform_message
|
||||
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
@stage.stage_class('PreProcessor')
|
||||
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
|
||||
|
||||
async def process(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
stage_inst_name: str,
|
||||
) -> entities.StageProcessResult:
|
||||
"""处理"""
|
||||
@@ -45,33 +45,35 @@ class PreProcessor(stage.PipelineStage):
|
||||
query,
|
||||
session,
|
||||
query.pipeline_config['ai']['local-agent']['prompt'],
|
||||
query.pipeline_uuid,
|
||||
query.bot_uuid,
|
||||
)
|
||||
|
||||
conversation.use_llm_model = llm_model
|
||||
|
||||
# 设置query
|
||||
query.session = session
|
||||
query.prompt = conversation.prompt.copy()
|
||||
query.messages = conversation.messages.copy()
|
||||
|
||||
query.use_llm_model = llm_model
|
||||
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||
|
||||
if selected_runner == 'local-agent':
|
||||
query.use_funcs = (
|
||||
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('tool_call') else None
|
||||
)
|
||||
query.use_funcs = []
|
||||
|
||||
query.variables = {
|
||||
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
|
||||
|
||||
variables = {
|
||||
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
'conversation_id': conversation.uuid,
|
||||
'msg_create_time': (
|
||||
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
|
||||
),
|
||||
}
|
||||
query.variables.update(variables)
|
||||
|
||||
# Check if this model supports vision, if not, remove all images
|
||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
|
||||
for msg in query.messages:
|
||||
if isinstance(msg.content, list):
|
||||
for me in msg.content:
|
||||
@@ -81,32 +83,39 @@ class PreProcessor(stage.PipelineStage):
|
||||
content_list = []
|
||||
|
||||
plain_text = ''
|
||||
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
|
||||
|
||||
for me in query.message_chain:
|
||||
if isinstance(me, platform_message.Plain):
|
||||
content_list.append(llm_entities.ContentElement.from_text(me.text))
|
||||
content_list.append(provider_message.ContentElement.from_text(me.text))
|
||||
plain_text += me.text
|
||||
elif isinstance(me, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
|
||||
'vision'
|
||||
):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if me.base64 is not None:
|
||||
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||
elif isinstance(me, platform_message.Quote) and qoute_msg:
|
||||
for msg in me.origin:
|
||||
if isinstance(msg, platform_message.Plain):
|
||||
content_list.append(provider_message.ContentElement.from_text(msg.text))
|
||||
elif isinstance(msg, platform_message.Image):
|
||||
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||
if msg.base64 is not None:
|
||||
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
|
||||
|
||||
query.variables['user_message_text'] = plain_text
|
||||
|
||||
query.user_message = llm_entities.Message(role='user', content=content_list)
|
||||
query.user_message = provider_message.Message(role='user', content=content_list)
|
||||
# =========== 触发事件 PromptPreProcessing
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=events.PromptPreProcessing(
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query,
|
||||
)
|
||||
event = events.PromptPreProcessing(
|
||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
default_prompt=query.prompt.messages,
|
||||
prompt=query.messages,
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
query.prompt.messages = event_ctx.event.default_prompt
|
||||
query.messages = event_ctx.event.prompt
|
||||
|
||||
|
||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
||||
import abc
|
||||
|
||||
from ...core import app
|
||||
from ...core import entities as core_entities
|
||||
from .. import entities
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
class MessageHandler(metaclass=abc.ABCMeta):
|
||||
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
|
||||
@abc.abstractmethod
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> entities.StageProcessResult:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -6,13 +6,15 @@ import traceback
|
||||
|
||||
from .. import handler
|
||||
from ... import entities
|
||||
from ....core import entities as core_entities
|
||||
from ....provider import runner as runner_module
|
||||
from ....plugin import events
|
||||
|
||||
from ....platform.types import message as platform_message
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import langbot_plugin.api.entities.events as events
|
||||
from ....utils import importutil
|
||||
from ....provider import runners
|
||||
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||
|
||||
|
||||
importutil.import_modules_in_pkg(runners)
|
||||
|
||||
@@ -20,7 +22,7 @@ importutil.import_modules_in_pkg(runners)
|
||||
class ChatMessageHandler(handler.MessageHandler):
|
||||
async def handle(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
query: pipeline_query.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""处理"""
|
||||
# 调API
|
||||
@@ -29,20 +31,20 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
# 触发插件事件
|
||||
event_class = (
|
||||
events.PersonNormalMessageReceived
|
||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
||||
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||
else events.GroupNormalMessageReceived
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
||||
event=event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
query=query,
|
||||
)
|
||||
event = event_class(
|
||||
launcher_type=query.launcher_type.value,
|
||||
launcher_id=query.launcher_id,
|
||||
sender_id=query.sender_id,
|
||||
text_message=str(query.message_chain),
|
||||
query=query,
|
||||
)
|
||||
|
||||
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user