Compare commits

..

53 Commits

Author SHA1 Message Date
Junyan Qin
d3a147bbdd chore: bump version 4.3.0b3 2025-08-23 20:08:29 +08:00
Junyan Qin
8eb1b8759b chore: bump version to '4.3.0b2' 2025-08-23 20:06:19 +08:00
Junyan Qin
0155d3b0b9 fix: conflict in table plugin_settings 2025-08-23 20:05:24 +08:00
Junyan Qin
e47a5b4e0d chore: bump langbot_plugin version 2025-08-23 17:12:29 +08:00
Junyan Qin
4012310d99 chore: bump version 4.3.0b1 2025-08-21 10:49:51 +08:00
Junyan Qin
9e9bc88473 chore: remove plugin reorder functionality 2025-08-21 10:47:53 +08:00
Junyan Qin
53ade384eb feat: bump version of langbot-plugin 2025-08-20 23:26:32 +08:00
Junyan Qin
8b2480ad3b feat: setting plugin config 2025-08-17 21:01:43 +08:00
Junyan Qin
b176959836 feat: plugin deletion and upgrade 2025-08-17 18:07:51 +08:00
Junyan Qin
a0c42a5f6e feat: plugin operations 2025-08-17 16:51:44 +08:00
Junyan Qin
17d997c88e fix: i18n fallback 2025-08-17 11:43:38 +08:00
Junyan Qin
0ea7609ff1 perf: frontend 2025-08-16 23:23:24 +08:00
Junyan Qin
28d4b1dd61 feat: marketplace page 2025-08-16 18:05:33 +08:00
Junyan Qin
5179b3e53a feat: trace plugin installation 2025-08-16 15:42:49 +08:00
Junyan Qin
288b294148 feat: plugin installation webui 2025-08-15 22:05:39 +08:00
Junyan Qin
b464d238c5 feat: plugin installation 2025-08-15 21:30:26 +08:00
Junyan Qin
e1a78e8ff9 feat: tag debugging plugins in webui 2025-08-15 19:11:49 +08:00
Junyan Qin
2b8eb5f01c fix: bot switching 2025-08-15 17:02:00 +08:00
Junyan Qin
bf2bc70794 feat: refactor webui httpclient 2025-08-14 23:55:14 +08:00
Junyan Qin
ebe0b68e8f feat: set cloud_service_url 2025-08-14 23:42:57 +08:00
Junyan Qin
39c50d3c12 feat: get_bot_info api 2025-08-13 20:54:43 +08:00
Junyan Qin
621f1301b3 fix: message chain init 2025-08-11 17:24:57 +08:00
Junyan Qin
0b60ef0d06 chore: bump langbot-plugin version to 0.1.1a1 2025-08-09 21:06:31 +08:00
Junyan Qin
41650b585a perf: dispose process 2025-08-02 23:54:06 +08:00
Junyan Qin
f5b893cfe0 feat: kill runtime process when exit in stdio mode 2025-07-16 22:43:39 +08:00
Junyan Qin
e0abd19636 feat: get plugin info 2025-07-13 22:14:22 +08:00
Junyan Qin
4380041c7f feat(ui): list plugins 2025-07-13 22:03:47 +08:00
Junyan Qin
65814a4644 feat: binary storage api 2025-07-13 21:39:33 +08:00
Junyan Qin
7237294008 perf: longer timeout for emit_event 2025-07-13 20:48:15 +08:00
Junyan Qin
214bc8ada9 feat: backward call apis 2025-07-13 20:45:45 +08:00
Junyan Qin
6a1de889b4 refactor: switch llm_entities to plugin sdk 2025-07-13 20:30:17 +08:00
Junyan Qin
4a319b2b20 feat: query-based apis 2025-07-13 18:41:04 +08:00
Junyan Qin
9f269d1614 feat: get bot uuid api 2025-07-13 17:44:20 +08:00
Junyan Qin
4b57771eb1 feat: reply_message api 2025-07-13 16:31:25 +08:00
Junyan Qin
5922be7e15 feat: command execution via plugin 2025-07-13 10:26:48 +08:00
Junyan Qin
10a44c70b6 feat: switch command entities to sdk 2025-07-10 10:51:36 +08:00
Junyan Qin
5b044a1917 feat: add Tool component 2025-07-06 21:03:33 +08:00
Junyan Qin
a60aa6f644 feat: runtime reconnecting 2025-07-02 22:20:20 +08:00
Junyan Qin
1a10b40b17 refactor: use emit_event from connector 2025-07-02 12:46:30 +08:00
Junyan Qin
e2124054bf feat: switch all event emitting logic to new method 2025-07-02 11:58:10 +08:00
Junyan Qin
ee3da8aa17 feat: adapt more events 2025-07-02 11:04:03 +08:00
Junyan Qin
c246470b37 feat: minor changes adapt to event emitting 2025-07-01 22:44:46 +08:00
Junyan Qin
f474e42b79 fix: serialization bug in events emitting 2025-06-30 21:49:59 +08:00
Junyan Qin
5553a86ac8 feat: preliminary migration of events entities 2025-06-30 21:49:59 +08:00
Junyan Qin
01613b2f0d chore: remove adapter meta manifest from components.yaml 2025-06-30 21:49:59 +08:00
Junyan Qin
a177786063 feat: switch message platform adapters to sdk 2025-06-30 21:49:59 +08:00
Junyan Qin
62b2884011 chore: delete Query class 2025-06-30 21:47:40 +08:00
Junyan Qin
6b782f8761 feat: switch Query to langbot-plugin definition 2025-06-30 21:47:40 +08:00
Junyan Qin
0c2560cafb feat: switch tool entities and format 2025-06-30 21:47:40 +08:00
Junyan Qin
c5eeab2fd0 feat: listing plugins 2025-06-30 21:43:43 +08:00
Junyan Qin
6f2fd72af6 feat(plugin): basic communication 2025-06-30 21:43:43 +08:00
Junyan Qin
2d06f1cadb feat: connector for plugin runtime 2025-06-30 21:43:43 +08:00
Junyan Qin
af493c117c deps: add langbot-plugin 2025-06-30 21:43:42 +08:00
203 changed files with 5449 additions and 6541 deletions

View File

@@ -119,7 +119,6 @@ docker compose up -d
| [Anthropic](https://www.anthropic.com/) | ✅ | | | [Anthropic](https://www.anthropic.com/) | ✅ | |
| [xAI](https://x.ai/) | ✅ | | | [xAI](https://x.ai/) | ✅ | |
| [智谱AI](https://open.bigmodel.cn/) | ✅ | | | [智谱AI](https://open.bigmodel.cn/) | ✅ | |
| [优云智算](https://www.compshare.cn/) | ✅ | 大模型和 GPU 资源平台 |
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型和 GPU 资源平台 | | [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型和 GPU 资源平台 |
| [302 AI](https://share.302.ai/SuTG99) | ✅ | 大模型聚合平台 | | [302 AI](https://share.302.ai/SuTG99) | ✅ | 大模型聚合平台 |
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | | | [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |

View File

@@ -116,7 +116,6 @@ Directly use the released version to run, see the [Manual Deployment](https://do
| [Anthropic](https://www.anthropic.com/) | ✅ | | | [Anthropic](https://www.anthropic.com/) | ✅ | |
| [xAI](https://x.ai/) | ✅ | | | [xAI](https://x.ai/) | ✅ | |
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | | | [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
| [CompShare](https://www.compshare.cn/) | ✅ | LLM and GPU resource platform |
| [Dify](https://dify.ai) | ✅ | LLMOps platform | | [Dify](https://dify.ai) | ✅ | LLMOps platform |
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | LLM and GPU resource platform | | [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | LLM and GPU resource platform |
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLM gateway(MaaS) | | [302 AI](https://share.302.ai/SuTG99) | ✅ | LLM gateway(MaaS) |

View File

@@ -115,7 +115,6 @@ LangBotはBTPanelにリストされています。BTPanelをインストール
| [Anthropic](https://www.anthropic.com/) | ✅ | | | [Anthropic](https://www.anthropic.com/) | ✅ | |
| [xAI](https://x.ai/) | ✅ | | | [xAI](https://x.ai/) | ✅ | |
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | | | [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
| [CompShare](https://www.compshare.cn/) | ✅ | 大模型とGPUリソースプラットフォーム |
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型とGPUリソースプラットフォーム | | [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型とGPUリソースプラットフォーム |
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLMゲートウェイ(MaaS) | | [302 AI](https://share.302.ai/SuTG99) | ✅ | LLMゲートウェイ(MaaS) |
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | | | [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |

View File

@@ -9,7 +9,6 @@ spec:
components: components:
ComponentTemplate: ComponentTemplate:
fromFiles: fromFiles:
- pkg/platform/adapter.yaml
- pkg/provider/modelmgr/requester.yaml - pkg/provider/modelmgr/requester.yaml
MessagePlatformAdapter: MessagePlatformAdapter:
fromDirs: fromDirs:

View File

@@ -3,7 +3,7 @@ from quart import request
import httpx import httpx
from quart import Quart from quart import Quart
from typing import Callable, Dict, Any from typing import Callable, Dict, Any
from pkg.platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from .qqofficialevent import QQOfficialEvent from .qqofficialevent import QQOfficialEvent
import json import json
import traceback import traceback
@@ -104,7 +104,7 @@ class QQOfficialClient:
return {'code': 0, 'message': 'success'} return {'code': 0, 'message': 'success'}
except Exception as e: except Exception as e:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}") await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
return {'error': str(e)}, 400 return {'error': str(e)}, 400
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):
@@ -168,7 +168,6 @@ class QQOfficialClient:
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + '/v2/users/' + user_openid + '/messages' url = self.base_url + '/v2/users/' + user_openid + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
@@ -193,7 +192,6 @@ class QQOfficialClient:
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + '/v2/groups/' + group_openid + '/messages' url = self.base_url + '/v2/groups/' + group_openid + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
@@ -209,7 +207,7 @@ class QQOfficialClient:
if response.status_code == 200: if response.status_code == 200:
return return
else: else:
await self.logger.error(f"发送群聊消息失败:{response.json()}") await self.logger.error(f'发送群聊消息失败:{response.json()}')
raise Exception(response.read().decode()) raise Exception(response.read().decode())
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str): async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
@@ -217,7 +215,6 @@ class QQOfficialClient:
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + '/channels/' + channel_id + '/messages' url = self.base_url + '/channels/' + channel_id + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
headers = { headers = {
@@ -240,7 +237,6 @@ class QQOfficialClient:
"""发送频道私聊消息""" """发送频道私聊消息"""
if not await self.check_access_token(): if not await self.check_access_token():
await self.get_access_token() await self.get_access_token()
url = self.base_url + '/dms/' + guild_id + '/messages' url = self.base_url + '/dms/' + guild_id + '/messages'
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:

View File

@@ -4,7 +4,7 @@ from quart import Quart, jsonify, request
from slack_sdk.web.async_client import AsyncWebClient from slack_sdk.web.async_client import AsyncWebClient
from .slackevent import SlackEvent from .slackevent import SlackEvent
from typing import Callable from typing import Callable
from pkg.platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
class SlackClient: class SlackClient:
@@ -34,7 +34,6 @@ class SlackClient:
if self.bot_user_id and bot_user_id == self.bot_user_id: if self.bot_user_id and bot_user_id == self.bot_user_id:
return jsonify({'status': 'ok'}) return jsonify({'status': 'ok'})
# 处理私信 # 处理私信
if data and data.get('event', {}).get('channel_type') in ['im']: if data and data.get('event', {}).get('channel_type') in ['im']:
@@ -52,7 +51,7 @@ class SlackClient:
return jsonify({'status': 'ok'}) return jsonify({'status': 'ok'})
except Exception as e: except Exception as e:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}") await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
raise (e) raise (e)
async def _handle_message(self, event: SlackEvent): async def _handle_message(self, event: SlackEvent):
@@ -82,7 +81,7 @@ class SlackClient:
self.bot_user_id = response['message']['bot_id'] self.bot_user_id = response['message']['bot_id']
return return
except Exception as e: except Exception as e:
await self.logger.error(f"Error in send_message: {e}") await self.logger.error(f'Error in send_message: {e}')
raise e raise e
async def send_message_to_one(self, text: str, user_id: str): async def send_message_to_one(self, text: str, user_id: str):
@@ -93,7 +92,7 @@ class SlackClient:
return return
except Exception as e: except Exception as e:
await self.logger.error(f"Error in send_message: {traceback.format_exc()}") await self.logger.error(f'Error in send_message: {traceback.format_exc()}')
raise e raise e
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):

View File

@@ -1 +1,4 @@
from .client import WeChatPadClient from .client import WeChatPadClient
__all__ = ['WeChatPadClient']

View File

@@ -1,4 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json from libs.wechatpad_api.util.http_util import post_json
class ChatRoomApi: class ChatRoomApi:
@@ -7,8 +7,6 @@ class ChatRoomApi:
self.token = token self.token = token
def get_chatroom_member_detail(self, chatroom_name): def get_chatroom_member_detail(self, chatroom_name):
params = { params = {'ChatRoomName': chatroom_name}
"ChatRoomName": chatroom_name
}
url = self.base_url + '/group/GetChatroomMemberDetail' url = self.base_url + '/group/GetChatroomMemberDetail'
return post_json(url, token=self.token, data=params) return post_json(url, token=self.token, data=params)

View File

@@ -1,32 +1,23 @@
from libs.wechatpad_api.util.http_util import async_request, post_json from libs.wechatpad_api.util.http_util import post_json
import httpx import httpx
import base64 import base64
class DownloadApi: class DownloadApi:
def __init__(self, base_url, token): def __init__(self, base_url, token):
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token
def send_download(self, aeskey, file_type, file_url): def send_download(self, aeskey, file_type, file_url):
json_data = { json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
"AesKey": aeskey, url = self.base_url + '/message/SendCdnDownload'
"FileType": file_type,
"FileURL": file_url
}
url = self.base_url + "/message/SendCdnDownload"
return post_json(url, token=self.token, data=json_data) return post_json(url, token=self.token, data=json_data)
def get_msg_voice(self,buf_id, length, new_msgid): def get_msg_voice(self, buf_id, length, new_msgid):
json_data = { json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
"Bufid": buf_id, url = self.base_url + '/message/GetMsgVoice'
"Length": length,
"NewMsgId": new_msgid,
"ToUserName": ""
}
url = self.base_url + "/message/GetMsgVoice"
return post_json(url, token=self.token, data=json_data) return post_json(url, token=self.token, data=json_data)
async def download_url_to_base64(self, download_url): async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(download_url) response = await client.get(download_url)
@@ -36,4 +27,4 @@ class DownloadApi:
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式 base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
return base64_str return base64_str
else: else:
raise Exception('获取文件失败') raise Exception('获取文件失败')

View File

@@ -1,11 +1,6 @@
from libs.wechatpad_api.util.http_util import post_json,async_request
from typing import List, Dict, Any, Optional
class FriendApi: class FriendApi:
"""联系人API类处理所有与联系人相关的操作""" """联系人API类处理所有与联系人相关的操作"""
def __init__(self, base_url: str, token: str): def __init__(self, base_url: str, token: str):
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token

View File

@@ -1,37 +1,34 @@
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json from libs.wechatpad_api.util.http_util import post_json, get_json
class LoginApi: class LoginApi:
def __init__(self, base_url: str, token: str = None, admin_key: str = None): def __init__(self, base_url: str, token: str = None, admin_key: str = None):
''' """
Args: Args:
base_url: 原始路径 base_url: 原始路径
token: token token: token
admin_key: 管理员key admin_key: 管理员key
''' """
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token
# self.admin_key = admin_key # self.admin_key = admin_key
def get_token(self, admin_key, day: int=365): def get_token(self, admin_key, day: int = 365):
# 获取普通token # 获取普通token
url = f"{self.base_url}/admin/GenAuthKey1" url = f'{self.base_url}/admin/GenAuthKey1'
json_data = { json_data = {'Count': 1, 'Days': day}
"Count": 1,
"Days": day
}
return post_json(base_url=url, token=admin_key, data=json_data) return post_json(base_url=url, token=admin_key, data=json_data)
def get_login_qr(self, Proxy: str = ""): def get_login_qr(self, Proxy: str = ''):
''' """
Args: Args:
Proxy:异地使用时代理 Proxy:异地使用时代理
Returns:json数据 Returns:json数据
''' """
""" """
{ {
@@ -49,54 +46,37 @@ class LoginApi:
} }
""" """
#获取登录二维码 # 获取登录二维码
url = f"{self.base_url}/login/GetLoginQrCodeNew" url = f'{self.base_url}/login/GetLoginQrCodeNew'
check = False check = False
if Proxy != "": if Proxy != '':
check = True check = True
json_data = { json_data = {'Check': check, 'Proxy': Proxy}
"Check": check,
"Proxy": Proxy
}
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def get_login_status(self): def get_login_status(self):
# 获取登录状态 # 获取登录状态
url = f'{self.base_url}/login/GetLoginStatus' url = f'{self.base_url}/login/GetLoginStatus'
return get_json(base_url=url, token=self.token) return get_json(base_url=url, token=self.token)
def logout(self): def logout(self):
# 退出登录 # 退出登录
url = f'{self.base_url}/login/LogOut' url = f'{self.base_url}/login/LogOut'
return post_json(base_url=url, token=self.token) return post_json(base_url=url, token=self.token)
def wake_up_login(self, Proxy: str = ''):
def wake_up_login(self, Proxy: str = ""):
# 唤醒登录 # 唤醒登录
url = f'{self.base_url}/login/WakeUpLogin' url = f'{self.base_url}/login/WakeUpLogin'
check = False check = False
if Proxy != "": if Proxy != '':
check = True check = True
json_data = { json_data = {'Check': check, 'Proxy': ''}
"Check": check,
"Proxy": ""
}
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def login(self, admin_key):
def login(self,admin_key):
login_status = self.get_login_status() login_status = self.get_login_status()
if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信": if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
print("token已经失效重新获取") print('token已经失效重新获取')
token_data = self.get_token(admin_key) token_data = self.get_token(admin_key)
self.token = token_data["Data"][0] self.token = token_data['Data'][0]

View File

@@ -1,5 +1,4 @@
from libs.wechatpad_api.util.http_util import post_json
from libs.wechatpad_api.util.http_util import async_request, post_json
class MessageApi: class MessageApi:
@@ -7,8 +6,8 @@ class MessageApi:
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token
def post_text(self, to_wxid, content, ats: list= []): def post_text(self, to_wxid, content, ats: list = []):
''' """
Args: Args:
app_id: 微信id app_id: 微信id
@@ -18,106 +17,64 @@ class MessageApi:
Returns: Returns:
''' """
url = self.base_url + "/message/SendTextMessage" url = self.base_url + '/message/SendTextMessage'
"""发送文字消息""" """发送文字消息"""
json_data = { json_data = {
"MsgItem": [ 'MsgItem': [
{ {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
"AtWxIDList": ats, ]
"ImageContent": "", }
"MsgType": 0, return post_json(base_url=url, token=self.token, data=json_data)
"TextContent": content,
"ToUserName": to_wxid
}
]
}
return post_json(base_url=url, token=self.token, data=json_data)
def post_image(self, to_wxid, img_url, ats: list = []):
def post_image(self, to_wxid, img_url, ats: list= []):
"""发送图片消息""" """发送图片消息"""
# 这里好像可以尝试发送多个暂时未测试 # 这里好像可以尝试发送多个暂时未测试
json_data = { json_data = {
"MsgItem": [ 'MsgItem': [
{ {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
"AtWxIDList": ats,
"ImageContent": img_url,
"MsgType": 0,
"TextContent": '',
"ToUserName": to_wxid
}
] ]
} }
url = self.base_url + "/message/SendImageMessage" url = self.base_url + '/message/SendImageMessage'
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration): def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
"""发送语音消息""" """发送语音消息"""
json_data = { json_data = {
"ToUserName": to_wxid, 'ToUserName': to_wxid,
"VoiceData": voice_data, 'VoiceData': voice_data,
"VoiceFormat": voice_forma, 'VoiceFormat': voice_forma,
"VoiceSecond": voice_duration 'VoiceSecond': voice_duration,
} }
url = self.base_url + "/message/SendVoice" url = self.base_url + '/message/SendVoice'
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag): def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
"""发送名片消息""" """发送名片消息"""
param = { param = {
"CardAlias": alias, 'CardAlias': alias,
"CardFlag": flag, 'CardFlag': flag,
"CardNickName": nick_name, 'CardNickName': nick_name,
"CardWxId": name_card_wxid, 'CardWxId': name_card_wxid,
"ToUserName": to_wxid 'ToUserName': to_wxid,
} }
url = f"{self.base_url}/message/ShareCardMessage" url = f'{self.base_url}/message/ShareCardMessage'
return post_json(base_url=url, token=self.token, data=param) return post_json(base_url=url, token=self.token, data=param)
def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0): def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
"""发送emoji消息""" """发送emoji消息"""
json_data = { json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
"EmojiList": [ url = f'{self.base_url}/message/SendEmojiMessage'
{
"EmojiMd5": emoji_md5,
"EmojiSize": emoji_size,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendEmojiMessage"
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_app_msg(self, to_wxid,xml_data, contenttype:int=0): def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
"""发送appmsg消息""" """发送appmsg消息"""
json_data = { json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
"AppList": [ url = f'{self.base_url}/message/SendAppMessage'
{
"ContentType": contenttype,
"ContentXML": xml_data,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendAppMessage"
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
"""撤回消息""" """撤回消息"""
param = { param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
"ClientMsgId": msg_id, url = f'{self.base_url}/message/RevokeMsg'
"CreateTime": create_time, return post_json(base_url=url, token=self.token, data=param)
"NewMsgId": new_msg_id,
"ToUserName": to_wxid
}
url = f"{self.base_url}/message/RevokeMsg"
return post_json(base_url=url, token=self.token, data=param)

View File

@@ -12,12 +12,9 @@ class UserApi:
return get_json(base_url=url, token=self.token) return get_json(base_url=url, token=self.token)
def get_qr_code(self, recover:bool=True, style:int=8): def get_qr_code(self, recover: bool = True, style: int = 8):
"""获取自己的二维码""" """获取自己的二维码"""
param = { param = {'Recover': recover, 'Style': style}
"Recover": recover,
"Style": style
}
url = f'{self.base_url}/user/GetMyQRCode' url = f'{self.base_url}/user/GetMyQRCode'
return post_json(base_url=url, token=self.token, data=param) return post_json(base_url=url, token=self.token, data=param)
@@ -26,12 +23,8 @@ class UserApi:
url = f'{self.base_url}/equipment/GetSafetyInfo' url = f'{self.base_url}/equipment/GetSafetyInfo'
return post_json(base_url=url, token=self.token) return post_json(base_url=url, token=self.token)
async def update_head_img(self, head_img_base64):
async def update_head_img(self, head_img_base64):
"""修改头像""" """修改头像"""
param = { param = {'Base64': head_img_base64}
"Base64": head_img_base64
}
url = f'{self.base_url}/user/UploadHeadImage' url = f'{self.base_url}/user/UploadHeadImage'
return await async_request(base_url=url, token_key=self.token, json=param) return await async_request(base_url=url, token_key=self.token, json=param)

View File

@@ -1,4 +1,3 @@
from libs.wechatpad_api.api.login import LoginApi from libs.wechatpad_api.api.login import LoginApi
from libs.wechatpad_api.api.friend import FriendApi from libs.wechatpad_api.api.friend import FriendApi
from libs.wechatpad_api.api.message import MessageApi from libs.wechatpad_api.api.message import MessageApi
@@ -7,9 +6,6 @@ from libs.wechatpad_api.api.downloadpai import DownloadApi
from libs.wechatpad_api.api.chatroom import ChatRoomApi from libs.wechatpad_api.api.chatroom import ChatRoomApi
class WeChatPadClient: class WeChatPadClient:
def __init__(self, base_url, token, logger=None): def __init__(self, base_url, token, logger=None):
self._login_api = LoginApi(base_url, token) self._login_api = LoginApi(base_url, token)
@@ -20,16 +16,16 @@ class WeChatPadClient:
self._chatroom_api = ChatRoomApi(base_url, token) self._chatroom_api = ChatRoomApi(base_url, token)
self.logger = logger self.logger = logger
def get_token(self,admin_key, day: int): def get_token(self, admin_key, day: int):
'''获取token''' """获取token"""
return self._login_api.get_token(admin_key, day) return self._login_api.get_token(admin_key, day)
def get_login_qr(self, Proxy:str=""): def get_login_qr(self, Proxy: str = ''):
"""登录二维码""" """登录二维码"""
return self._login_api.get_login_qr(Proxy=Proxy) return self._login_api.get_login_qr(Proxy=Proxy)
def awaken_login(self, Proxy:str=""): def awaken_login(self, Proxy: str = ''):
'''唤醒登录''' """唤醒登录"""
return self._login_api.wake_up_login(Proxy=Proxy) return self._login_api.wake_up_login(Proxy=Proxy)
def log_out(self): def log_out(self):
@@ -40,59 +36,57 @@ class WeChatPadClient:
"""获取登录状态""" """获取登录状态"""
return self._login_api.get_login_status() return self._login_api.get_login_status()
def send_text_message(self, to_wxid, message, ats: list=[]): def send_text_message(self, to_wxid, message, ats: list = []):
"""发送文本消息""" """发送文本消息"""
return self._message_api.post_text(to_wxid, message, ats) return self._message_api.post_text(to_wxid, message, ats)
def send_image_message(self, to_wxid, img_url, ats: list=[]): def send_image_message(self, to_wxid, img_url, ats: list = []):
"""发送图片消息""" """发送图片消息"""
return self._message_api.post_image(to_wxid, img_url, ats) return self._message_api.post_image(to_wxid, img_url, ats)
def send_voice_message(self, to_wxid, voice_data, voice_forma, voice_duration): def send_voice_message(self, to_wxid, voice_data, voice_forma, voice_duration):
"""发送音频消息""" """发送音频消息"""
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration) return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
def send_app_message(self, to_wxid, app_message, type): def send_app_message(self, to_wxid, app_message, type):
"""发送app消息""" """发送app消息"""
return self._message_api.post_app_msg(to_wxid, app_message, type) return self._message_api.post_app_msg(to_wxid, app_message, type)
def send_emoji_message(self, to_wxid, emoji_md5, emoji_size): def send_emoji_message(self, to_wxid, emoji_md5, emoji_size):
"""发送emoji消息""" """发送emoji消息"""
return self._message_api.post_emoji(to_wxid,emoji_md5,emoji_size) return self._message_api.post_emoji(to_wxid, emoji_md5, emoji_size)
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
"""撤回消息""" """撤回消息"""
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time) return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
def get_profile(self): def get_profile(self):
"""获取用户信息""" """获取用户信息"""
return self._user_api.get_profile() return self._user_api.get_profile()
def get_qr_code(self, recover:bool=True, style:int=8): def get_qr_code(self, recover: bool = True, style: int = 8):
"""获取用户二维码""" """获取用户二维码"""
return self._user_api.get_qr_code(recover=recover, style=style) return self._user_api.get_qr_code(recover=recover, style=style)
def get_safety_info(self): def get_safety_info(self):
"""获取设备信息""" """获取设备信息"""
return self._user_api.get_safety_info() return self._user_api.get_safety_info()
def update_head_img(self, head_img_base64): def update_head_img(self, head_img_base64):
"""上传用户头像""" """上传用户头像"""
return self._user_api.update_head_img(head_img_base64) return self._user_api.update_head_img(head_img_base64)
def cdn_download(self, aeskey, file_type, file_url): def cdn_download(self, aeskey, file_type, file_url):
"""cdn下载""" """cdn下载"""
return self._download_api.send_download( aeskey, file_type, file_url) return self._download_api.send_download(aeskey, file_type, file_url)
def get_msg_voice(self,buf_id, length, msgid): def get_msg_voice(self, buf_id, length, msgid):
"""下载语音""" """下载语音"""
return self._download_api.get_msg_voice(buf_id, length, msgid) return self._download_api.get_msg_voice(buf_id, length, msgid)
async def download_base64(self,url): async def download_base64(self, url):
return await self._download_api.download_url_to_base64(download_url=url) return await self._download_api.download_url_to_base64(download_url=url)
def get_chatroom_member_detail(self, chatroom_name): def get_chatroom_member_detail(self, chatroom_name):
"""查看群成员详情""" """查看群成员详情"""
return self._chatroom_api.get_chatroom_member_detail(chatroom_name) return self._chatroom_api.get_chatroom_member_detail(chatroom_name)

View File

@@ -1,10 +1,9 @@
import requests import requests
import aiohttp
def post_json(base_url, token, data=None): def post_json(base_url, token, data=None):
headers = { headers = {'Content-Type': 'application/json'}
'Content-Type': 'application/json'
}
url = base_url + f'?key={token}' url = base_url + f'?key={token}'
@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
else: else:
raise RuntimeError(response.text) raise RuntimeError(response.text)
except Exception as e: except Exception as e:
print(f"http请求失败, url={url}, exception={e}") print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e)) raise RuntimeError(str(e))
def get_json(base_url, token):
headers = {
'Content-Type': 'application/json'
}
def get_json(base_url, token):
headers = {'Content-Type': 'application/json'}
url = base_url + f'?key={token}' url = base_url + f'?key={token}'
@@ -39,21 +36,18 @@ def get_json(base_url, token):
else: else:
raise RuntimeError(response.text) raise RuntimeError(response.text)
except Exception as e: except Exception as e:
print(f"http请求失败, url={url}, exception={e}") print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e)) raise RuntimeError(str(e))
import aiohttp
import asyncio
async def async_request( async def async_request(
base_url: str, base_url: str,
token_key: str, token_key: str,
method: str = 'POST', method: str = 'POST',
params: dict = None, params: dict = None,
# headers: dict = None, # headers: dict = None,
data: dict = None, data: dict = None,
json: dict = None json: dict = None,
): ):
""" """
通用异步请求函数 通用异步请求函数
@@ -67,18 +61,11 @@ async def async_request(
:param json: JSON数据 :param json: JSON数据
:return: 响应文本 :return: 响应文本
""" """
headers = { headers = {'Content-Type': 'application/json'}
'Content-Type': 'application/json' url = f'{base_url}?key={token_key}'
}
url = f"{base_url}?key={token_key}"
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(
method=method, method=method, url=url, params=params, headers=headers, data=data, json=json
url=url,
params=params,
headers=headers,
data=data,
json=json
) as response: ) as response:
response.raise_for_status() # 如果状态码不是200抛出异常 response.raise_for_status() # 如果状态码不是200抛出异常
result = await response.json() result = await response.json()
@@ -89,4 +76,3 @@ async def async_request(
# return await result # return await result
# else: # else:
# raise RuntimeError("请求失败",response.text) # raise RuntimeError("请求失败",response.text)

View File

@@ -1,31 +1,34 @@
import qrcode import qrcode
def print_green(text): def print_green(text):
print(f"\033[32m{text}\033[0m") print(f'\033[32m{text}\033[0m')
def print_yellow(text): def print_yellow(text):
print(f"\033[33m{text}\033[0m") print(f'\033[33m{text}\033[0m')
def print_red(text): def print_red(text):
print(f"\033[31m{text}\033[0m") print(f'\033[31m{text}\033[0m')
def make_and_print_qr(url): def make_and_print_qr(url):
"""生成并打印二维码 """生成并打印二维码
Args: Args:
url: 需要生成二维码的URL字符串 url: 需要生成二维码的URL字符串
Returns: Returns:
None None
功能: 功能:
1. 在终端打印二维码的ASCII图形 1. 在终端打印二维码的ASCII图形
2. 同时提供在线二维码生成链接作为备选 2. 同时提供在线二维码生成链接作为备选
""" """
print_green("请扫描下方二维码登录") print_green('请扫描下方二维码登录')
qr = qrcode.QRCode() qr = qrcode.QRCode()
qr.add_data(url) qr.add_data(url)
qr.make() qr.make()
qr.print_ascii(invert=True) qr.print_ascii(invert=True)
print_green(f"也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}") print_green(f'也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}')

View File

@@ -8,7 +8,7 @@ from quart import Quart
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any from typing import Callable, Dict, Any
from .wecomevent import WecomEvent from .wecomevent import WecomEvent
from pkg.platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import aiofiles import aiofiles
@@ -57,7 +57,7 @@ class WecomClient:
if 'access_token' in data: if 'access_token' in data:
return data['access_token'] return data['access_token']
else: else:
await self.logger.error(f"获取accesstoken失败:{response.json()}") await self.logger.error(f'获取accesstoken失败:{response.json()}')
raise Exception(f'未获取access token: {data}') raise Exception(f'未获取access token: {data}')
async def get_users(self): async def get_users(self):
@@ -129,7 +129,7 @@ class WecomClient:
response = await client.post(url, json=params) response = await client.post(url, json=params)
data = response.json() data = response.json()
except Exception as e: except Exception as e:
await self.logger.error(f"发送图片失败:{data}") await self.logger.error(f'发送图片失败:{data}')
raise Exception('Failed to send image: ' + str(e)) raise Exception('Failed to send image: ' + str(e))
# 企业微信错误码40014和42001代表accesstoken问题 # 企业微信错误码40014和42001代表accesstoken问题
@@ -164,7 +164,7 @@ class WecomClient:
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
return await self.send_private_msg(user_id, agent_id, content) return await self.send_private_msg(user_id, agent_id, content)
if data['errcode'] != 0: if data['errcode'] != 0:
await self.logger.error(f"发送消息失败:{data}") await self.logger.error(f'发送消息失败:{data}')
raise Exception('Failed to send message: ' + str(data)) raise Exception('Failed to send message: ' + str(data))
async def handle_callback_request(self): async def handle_callback_request(self):
@@ -181,7 +181,7 @@ class WecomClient:
echostr = request.args.get('echostr') echostr = request.args.get('echostr')
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0: if ret != 0:
await self.logger.error("验证失败") await self.logger.error('验证失败')
raise Exception(f'验证失败,错误码: {ret}') raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str return reply_echo_str
@@ -189,9 +189,8 @@ class WecomClient:
encrypt_msg = await request.data encrypt_msg = await request.data
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce) ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0: if ret != 0:
await self.logger.error("消息解密失败") await self.logger.error('消息解密失败')
raise Exception(f'消息解密失败,错误码: {ret}') raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理 # 解析消息并处理
message_data = await self.get_message(xml_msg) message_data = await self.get_message(xml_msg)
@@ -202,7 +201,7 @@ class WecomClient:
return 'success' return 'success'
except Exception as e: except Exception as e:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}") await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
return f'Error processing request: {str(e)}', 400 return f'Error processing request: {str(e)}', 400
async def run_task(self, host: str, port: int, *args, **kwargs): async def run_task(self, host: str, port: int, *args, **kwargs):
@@ -301,7 +300,7 @@ class WecomClient:
except binascii.Error as e: except binascii.Error as e:
raise ValueError(f'Invalid base64 string: {str(e)}') raise ValueError(f'Invalid base64 string: {str(e)}')
else: else:
await self.logger.error("Image对象出错") await self.logger.error('Image对象出错')
raise ValueError('image对象出错') raise ValueError('image对象出错')
# 设置 multipart/form-data 格式的文件 # 设置 multipart/form-data 格式的文件
@@ -325,7 +324,7 @@ class WecomClient:
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
media_id = await self.upload_to_work(image) media_id = await self.upload_to_work(image)
if data.get('errcode', 0) != 0: if data.get('errcode', 0) != 0:
await self.logger.error(f"上传图片失败:{data}") await self.logger.error(f'上传图片失败:{data}')
raise Exception('failed to upload file') raise Exception('failed to upload file')
media_id = data.get('media_id') media_id = data.get('media_id')

View File

@@ -8,7 +8,7 @@ from quart import Quart
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Callable from typing import Callable
from .wecomcsevent import WecomCSEvent from .wecomcsevent import WecomCSEvent
from pkg.platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import aiofiles import aiofiles
@@ -187,7 +187,7 @@ class WecomCSClient:
self.access_token = await self.get_access_token(self.secret) self.access_token = await self.get_access_token(self.secret)
return await self.send_text_msg(open_kfid, external_userid, msgid, content) return await self.send_text_msg(open_kfid, external_userid, msgid, content)
if data['errcode'] != 0: if data['errcode'] != 0:
await self.logger.error(f"发送消息失败:{data}") await self.logger.error(f'发送消息失败:{data}')
raise Exception('Failed to send message') raise Exception('Failed to send message')
return data return data
@@ -227,7 +227,7 @@ class WecomCSClient:
return 'success' return 'success'
except Exception as e: except Exception as e:
if self.logger: if self.logger:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}") await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
else: else:
traceback.print_exc() traceback.print_exc()
return f'Error processing request: {str(e)}', 400 return f'Error processing request: {str(e)}', 400

10
main.py
View File

@@ -47,13 +47,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
if not args.skip_plugin_deps_check: if not args.skip_plugin_deps_check:
await deps.precheck_plugin_deps() await deps.precheck_plugin_deps()
# 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1 # # 检查pydantic版本如果没有 pydantic.v1则把 pydantic 映射为 v1
import pydantic.version # import pydantic.version
if pydantic.version.VERSION < '2.0': # if pydantic.version.VERSION < '2.0':
import pydantic # import pydantic
sys.modules['pydantic.v1'] = pydantic # sys.modules['pydantic.v1'] = pydantic
# 检查配置文件 # 检查配置文件

View File

@@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import base64
import quart import quart
from .....core import taskmgr from .....core import taskmgr
from .. import group from .. import group
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
@group.group_class('plugins', '/api/v1/plugins') @group.group_class('plugins', '/api/v1/plugins')
@@ -12,35 +13,22 @@ class PluginsRouterGroup(group.RouterGroup):
async def initialize(self) -> None: async def initialize(self) -> None:
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN) @self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
plugins = self.ap.plugin_mgr.plugins() plugins = await self.ap.plugin_connector.list_plugins()
plugins_data = [plugin.model_dump() for plugin in plugins] return self.success(data={'plugins': plugins})
return self.success(data={'plugins': plugins_data})
@self.route( @self.route(
'/<author>/<plugin_name>/toggle', '/<author>/<plugin_name>/upgrade',
methods=['PUT'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _(author: str, plugin_name: str) -> str:
data = await quart.request.json
target_enabled = data.get('target_enabled')
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
return self.success()
@self.route(
'/<author>/<plugin_name>/update',
methods=['POST'], methods=['POST'],
auth_type=group.AuthType.USER_TOKEN, auth_type=group.AuthType.USER_TOKEN,
) )
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx), self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx),
kind='plugin-operation', kind='plugin-operation',
name=f'plugin-update-{plugin_name}', name=f'plugin-upgrade-{plugin_name}',
label=f'更新插件 {plugin_name}', label=f'Upgrading plugin {plugin_name}',
context=ctx, context=ctx,
) )
return self.success(data={'task_id': wrapper.id}) return self.success(data={'task_id': wrapper.id})
@@ -52,17 +40,17 @@ class PluginsRouterGroup(group.RouterGroup):
) )
async def _(author: str, plugin_name: str) -> str: async def _(author: str, plugin_name: str) -> str:
if quart.request.method == 'GET': if quart.request.method == 'GET':
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
if plugin is None: if plugin is None:
return self.http_status(404, -1, 'plugin not found') return self.http_status(404, -1, 'plugin not found')
return self.success(data={'plugin': plugin.model_dump()}) return self.success(data={'plugin': plugin})
elif quart.request.method == 'DELETE': elif quart.request.method == 'DELETE':
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx), self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
kind='plugin-operation', kind='plugin-operation',
name=f'plugin-remove-{plugin_name}', name=f'plugin-remove-{plugin_name}',
label=f'删除插件 {plugin_name}', label=f'Removing plugin {plugin_name}',
context=ctx, context=ctx,
) )
@@ -74,24 +62,19 @@ class PluginsRouterGroup(group.RouterGroup):
auth_type=group.AuthType.USER_TOKEN, auth_type=group.AuthType.USER_TOKEN,
) )
async def _(author: str, plugin_name: str) -> quart.Response: async def _(author: str, plugin_name: str) -> quart.Response:
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name) plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
if plugin is None: if plugin is None:
return self.http_status(404, -1, 'plugin not found') return self.http_status(404, -1, 'plugin not found')
if quart.request.method == 'GET': if quart.request.method == 'GET':
return self.success(data={'config': plugin.plugin_config}) return self.success(data={'config': plugin['plugin_config']})
elif quart.request.method == 'PUT': elif quart.request.method == 'PUT':
data = await quart.request.json data = await quart.request.json
await self.ap.plugin_mgr.set_plugin_config(plugin, data) await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data)
return self.success(data={}) return self.success(data={})
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
return self.success()
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
data = await quart.request.json data = await quart.request.json
@@ -102,7 +85,47 @@ class PluginsRouterGroup(group.RouterGroup):
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx), self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
kind='plugin-operation', kind='plugin-operation',
name='plugin-install-github', name='plugin-install-github',
label=f'安装插件 ...{short_source_str}', label=f'Installing plugin from github ...{short_source_str}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
data = await quart.request.json
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
kind='plugin-operation',
name='plugin-install-marketplace',
label=f'Installing plugin from marketplace ...{data}',
context=ctx,
)
return self.success(data={'task_id': wrapper.id})
@self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
file = (await quart.request.files).get('file')
if file is None:
return self.http_status(400, -1, 'file is required')
file_bytes = file.read()
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
data = {
'plugin_file': file_base64,
}
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
kind='plugin-operation',
name='plugin-install-local',
label=f'Installing plugin from local ...{file.filename}',
context=ctx, context=ctx,
) )

View File

@@ -14,6 +14,11 @@ class SystemRouterGroup(group.RouterGroup):
'version': constants.semantic_version, 'version': constants.semantic_version,
'debug': constants.debug_mode, 'debug': constants.debug_mode,
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()), 'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
'cloud_service_url': (
self.ap.instance_config.data['plugin']['cloud_service_url']
if 'cloud_service_url' in self.ap.instance_config.data['plugin']
else 'https://space.langbot.app'
),
} }
) )
@@ -35,16 +40,7 @@ class SystemRouterGroup(group.RouterGroup):
return self.success(data=task.to_dict()) return self.success(data=task.to_dict())
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN) @self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
json_data = await quart.request.json
scope = json_data.get('scope')
await self.ap.reload(scope=scope)
return self.success()
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str: async def _() -> str:
if not constants.debug_mode: if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden') return self.http_status(403, 403, 'Forbidden')
@@ -54,3 +50,39 @@ class SystemRouterGroup(group.RouterGroup):
ap = self.ap ap = self.ap
return self.success(data=exec(py_code, {'ap': ap})) return self.success(data=exec(py_code, {'ap': ap}))
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
return self.success(
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
)
@self.route(
'/debug/plugin/action',
methods=['POST'],
auth_type=group.AuthType.USER_TOKEN,
)
async def _() -> str:
if not constants.debug_mode:
return self.http_status(403, 403, 'Forbidden')
data = await quart.request.json
class AnoymousAction:
value = 'anonymous_action'
def __init__(self, value: str):
self.value = value
resp = await self.ap.plugin_connector.handler.call_action(
AnoymousAction(data['action']),
data['data'],
timeout=data.get('timeout', 10),
)
return self.success(data=resp)

View File

@@ -17,15 +17,19 @@ class BotService:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
async def get_bots(self) -> list[dict]: async def get_bots(self, include_secret: bool = True) -> list[dict]:
"""获取所有机器人""" """获取所有机器人"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot)) result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
bots = result.all() bots = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots] masked_columns = []
if not include_secret:
masked_columns = ['adapter_config']
async def get_bot(self, bot_uuid: str) -> dict | None: return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots]
async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None:
"""获取机器人""" """获取机器人"""
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid) sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
@@ -36,7 +40,27 @@ class BotService:
if bot is None: if bot is None:
return None return None
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) masked_columns = []
if not include_secret:
masked_columns = ['adapter_config']
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns)
async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict:
"""获取机器人运行时信息"""
persistence_bot = await self.get_bot(bot_uuid, include_secret)
if persistence_bot is None:
raise Exception('Bot not found')
adapter_runtime_values = {}
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
if runtime_bot is not None:
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
return persistence_bot
async def create_bot(self, bot_data: dict) -> str: async def create_bot(self, bot_data: dict) -> str:
"""创建机器人""" """创建机器人"""

View File

@@ -7,7 +7,7 @@ from ....core import app
from ....entity.persistence import model as persistence_model from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester from ....provider.modelmgr import requester as model_requester
from ....provider import entities as llm_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message
class ModelsService: class ModelsService:
@@ -16,11 +16,19 @@ class ModelsService:
def __init__(self, ap: app.Application) -> None: def __init__(self, ap: app.Application) -> None:
self.ap = ap self.ap = ap
async def get_llm_models(self) -> list[dict]: async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel)) result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
models = result.all() models = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
masked_columns = []
if not include_secret:
masked_columns = ['api_keys']
return [
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
for model in models
]
async def create_llm_model(self, model_data: dict) -> str: async def create_llm_model(self, model_data: dict) -> str:
model_data['uuid'] = str(uuid.uuid4()) model_data['uuid'] = str(uuid.uuid4())
@@ -99,7 +107,7 @@ class ModelsService:
await runtime_llm_model.requester.invoke_llm( await runtime_llm_model.requester.invoke_llm(
query=None, query=None,
model=runtime_llm_model, model=runtime_llm_model,
messages=[llm_entities.Message(role='user', content='Hello, world!')], messages=[provider_message.Message(role='user', content='Hello, world!')],
funcs=[], funcs=[],
extra_args={}, extra_args={},
) )

View File

@@ -2,9 +2,12 @@ from __future__ import annotations
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities, operator, errors from . import operator
from ..utils import importutil from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
# 引入所有算子以便注册 # 引入所有算子以便注册
from . import operators from . import operators
@@ -13,13 +16,11 @@ importutil.import_modules_in_pkg(operators)
class CommandManager: class CommandManager:
"""命令管理器"""
ap: app.Application ap: app.Application
cmd_list: list[operator.CommandOperator] cmd_list: list[operator.CommandOperator]
""" """
运行时命令列表,扁平存储,各个对象包含对应的子节点引用 Runtime command list, flat storage, each object contains a reference to the corresponding child node
""" """
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
@@ -55,43 +56,28 @@ class CommandManager:
async def _execute( async def _execute(
self, self,
context: entities.ExecuteContext, context: command_context.ExecuteContext,
operator_list: list[operator.CommandOperator], operator_list: list[operator.CommandOperator],
operator: operator.CommandOperator = None, operator: operator.CommandOperator = None,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行命令""" """执行命令"""
found = False command_list = await self.ap.plugin_connector.list_commands()
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
for oper in operator_list:
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
oper.parent_class is None or oper.parent_class == operator.__class__
):
found = True
context.crt_command = context.crt_params[0] for command in command_list:
context.crt_params = context.crt_params[1:] if command.metadata.name == context.command:
async for ret in self.ap.plugin_connector.execute_command(context):
async for ret in self._execute(context, oper.children, oper): yield ret
yield ret break
break else:
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command))
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
if operator is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
else:
if operator.lowest_privilege > context.privilege:
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
else:
async for ret in operator.execute(context):
yield ret
async def execute( async def execute(
self, self,
command_text: str, command_text: str,
query: core_entities.Query, query: pipeline_query.Query,
session: core_entities.Session, session: provider_session.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]: ) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行命令""" """执行命令"""
privilege = 1 privilege = 1
@@ -99,8 +85,8 @@ class CommandManager:
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']: if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
privilege = 2 privilege = 2
ctx = entities.ExecuteContext( ctx = command_context.ExecuteContext(
query=query, query_id=query.query_id,
session=session, session=session,
command_text=command_text, command_text=command_text,
command='', command='',
@@ -110,5 +96,9 @@ class CommandManager:
privilege=privilege, privilege=privilege,
) )
ctx.command = ctx.params[0]
ctx.shift()
async for ret in self._execute(ctx, self.cmd_list): async for ret in self._execute(ctx, self.cmd_list):
yield ret yield ret

View File

@@ -1,74 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
from ..core import entities as core_entities
from . import errors
from ..platform.types import message as platform_message
class CommandReturn(pydantic.BaseModel):
"""命令返回值"""
text: typing.Optional[str] = None
"""文本
"""
image: typing.Optional[platform_message.Image] = None
"""弃用"""
image_url: typing.Optional[str] = None
"""图片链接
"""
error: typing.Optional[errors.CommandError] = None
"""错误
"""
class Config:
arbitrary_types_allowed = True
class ExecuteContext(pydantic.BaseModel):
"""单次命令执行上下文"""
query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session
"""本次消息所属的会话对象"""
command_text: str
"""命令完整文本"""
command: str
"""命令名称"""
crt_command: str
"""当前命令
多级命令中crt_command为当前命令command为根命令。
例如:!plugin on Webwlkr
处理到plugin时command为plugincrt_command为plugin
处理到on时command为plugincrt_command为on
"""
params: list[str]
"""命令参数
整个命令以空格分割后的参数列表
"""
crt_params: list[str]
"""当前命令参数
多级命令中crt_params为当前命令参数params为根命令参数。
例如:!plugin on Webwlkr
处理到plugin时params为['on', 'Webwlkr']crt_params为['on', 'Webwlkr']
处理到on时params为['on', 'Webwlkr']crt_params为['Webwlkr']
"""
privilege: int
"""发起人权限"""

View File

@@ -1,26 +0,0 @@
class CommandError(Exception):
def __init__(self, message: str = None):
self.message = message
def __str__(self):
return self.message
class CommandNotFoundError(CommandError):
def __init__(self, message: str = None):
super().__init__('未知命令: ' + message)
class CommandPrivilegeError(CommandError):
def __init__(self, message: str = None):
super().__init__('权限不足: ' + message)
class ParamNotEnoughError(CommandError):
def __init__(self, message: str = None):
super().__init__('参数不足: ' + message)
class CommandOperationError(CommandError):
def __init__(self, message: str = None):
super().__init__('操作失败: ' + message)

View File

@@ -4,7 +4,7 @@ import typing
import abc import abc
from ..core import app from ..core import app
from . import entities from langbot_plugin.api.entities.builtin.command import context as command_context
preregistered_operators: list[typing.Type[CommandOperator]] = [] preregistered_operators: list[typing.Type[CommandOperator]] = []
@@ -95,16 +95,18 @@ class CommandOperator(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""实现此方法以执行命令 """实现此方法以执行命令
支持多次yield以返回多个结果。 支持多次yield以返回多个结果。
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。 例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
Args: Args:
context (entities.ExecuteContext): 命令执行上下文 context (command_context.ExecuteContext): 命令执行上下文
Yields: Yields:
entities.CommandReturn: 命令返回封装 command_context.CommandReturn: 命令返回封装
""" """
pass pass

View File

@@ -2,14 +2,17 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>') @operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
class CmdOperator(operator.CommandOperator): class CmdOperator(operator.CommandOperator):
"""命令列表""" """命令列表"""
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行""" """执行"""
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
reply_str = '当前所有命令: \n\n' reply_str = '当前所有命令: \n\n'
@@ -20,7 +23,7 @@ class CmdOperator(operator.CommandOperator):
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助' reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
yield entities.CommandReturn(text=reply_str.strip()) yield command_context.CommandReturn(text=reply_str.strip())
else: else:
cmd_name = context.crt_params[0] cmd_name = context.crt_params[0]
@@ -33,9 +36,9 @@ class CmdOperator(operator.CommandOperator):
break break
if cmd is None: if cmd is None:
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name)) yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name))
else: else:
reply_str = f'{cmd.name}: {cmd.help}\n\n' reply_str = f'{cmd.name}: {cmd.help}\n\n'
reply_str += f'使用方法: \n{cmd.usage}' reply_str += f'使用方法: \n{cmd.usage}'
yield entities.CommandReturn(text=reply_str.strip()) yield command_context.CommandReturn(text=reply_str.strip())

View File

@@ -2,23 +2,26 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all') @operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
class DelOperator(operator.CommandOperator): class DelOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
delete_index = 0 delete_index = 0
if len(context.crt_params) > 0: if len(context.crt_params) > 0:
try: try:
delete_index = int(context.crt_params[0]) delete_index = int(context.crt_params[0])
except Exception: except Exception:
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数'))
return return
if delete_index < 0 or delete_index >= len(context.session.conversations): if delete_index < 0 or delete_index >= len(context.session.conversations):
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围'))
return return
# 倒序 # 倒序
@@ -29,15 +32,17 @@ class DelOperator(operator.CommandOperator):
del context.session.conversations[to_delete_index] del context.session.conversations[to_delete_index]
yield entities.CommandReturn(text=f'已删除对话: {delete_index}') yield command_context.CommandReturn(text=f'已删除对话: {delete_index}')
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator) @operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
class DelAllOperator(operator.CommandOperator): class DelAllOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
context.session.conversations = [] context.session.conversations = []
context.session.using_conversation = None context.session.using_conversation = None
yield entities.CommandReturn(text='已删除所有对话') yield command_context.CommandReturn(text='已删除所有对话')

View File

@@ -1,19 +1,20 @@
from __future__ import annotations from __future__ import annotations
from typing import AsyncGenerator from typing import AsyncGenerator
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func') @operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
class FuncOperator(operator.CommandOperator): class FuncOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> AsyncGenerator[command_context.CommandReturn, None]:
reply_str = '当前已启用的内容函数: \n\n' reply_str = '当前已启用的内容函数: \n\n'
index = 1 index = 1
all_functions = await self.ap.tool_mgr.get_all_functions( all_functions = await self.ap.tool_mgr.get_all_tools()
plugin_enabled=True,
)
for func in all_functions: for func in all_functions:
reply_str += '{}. {}:\n{}\n\n'.format( reply_str += '{}. {}:\n{}\n\n'.format(
@@ -23,4 +24,4 @@ class FuncOperator(operator.CommandOperator):
) )
index += 1 index += 1
yield entities.CommandReturn(text=reply_str) yield command_context.CommandReturn(text=reply_str)

View File

@@ -2,14 +2,17 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>') @operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
class HelpOperator(operator.CommandOperator): class HelpOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app' help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接https://langbot.app'
help += '\n发送命令 !cmd 可查看命令列表' help += '\n发送命令 !cmd 可查看命令列表'
yield entities.CommandReturn(text=help) yield command_context.CommandReturn(text=help)

View File

@@ -3,26 +3,31 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last') @operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
class LastOperator(operator.CommandOperator): class LastOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
# 找到当前会话的上一个会话 # 找到当前会话的上一个会话
for index in range(len(context.session.conversations) - 1, -1, -1): for index in range(len(context.session.conversations) - 1, -1, -1):
if context.session.conversations[index] == context.session.using_conversation: if context.session.conversations[index] == context.session.using_conversation:
if index == 0: if index == 0:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了')) yield command_context.CommandReturn(
error=command_errors.CommandOperationError('已经是第一个对话了')
)
return return
else: else:
context.session.using_conversation = context.session.conversations[index - 1] context.session.using_conversation = context.session.conversations[index - 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn( yield command_context.CommandReturn(
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}' text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
) )
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))

View File

@@ -2,19 +2,22 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>') @operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
class ListOperator(operator.CommandOperator): class ListOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
page = 0 page = 0
if len(context.crt_params) > 0: if len(context.crt_params) > 0:
try: try:
page = int(context.crt_params[0] - 1) page = int(context.crt_params[0] - 1)
except Exception: except Exception:
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数'))
return return
record_per_page = 10 record_per_page = 10
@@ -45,4 +48,4 @@ class ListOperator(operator.CommandOperator):
else: else:
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}' content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
yield entities.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}') yield command_context.CommandReturn(text=f'{page + 1} 页 (时间倒序):\n{content}')

View File

@@ -2,26 +2,31 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next') @operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
class NextOperator(operator.CommandOperator): class NextOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if context.session.conversations: if context.session.conversations:
# 找到当前会话的下一个会话 # 找到当前会话的下一个会话
for index in range(len(context.session.conversations)): for index in range(len(context.session.conversations)):
if context.session.conversations[index] == context.session.using_conversation: if context.session.conversations[index] == context.session.using_conversation:
if index == len(context.session.conversations) - 1: if index == len(context.session.conversations) - 1:
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了')) yield command_context.CommandReturn(
error=command_errors.CommandOperationError('已经是最后一个对话了')
)
return return
else: else:
context.session.using_conversation = context.session.conversations[index + 1] context.session.using_conversation = context.session.conversations[index + 1]
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S') time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
yield entities.CommandReturn( yield command_context.CommandReturn(
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}' text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
) )
return return
else: else:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import typing import typing
import traceback import traceback
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class( @operator.operator_class(
@@ -11,7 +12,9 @@ from .. import operator, entities, errors
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>', usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
) )
class PluginOperator(operator.CommandOperator): class PluginOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
plugin_list = self.ap.plugin_mgr.plugins() plugin_list = self.ap.plugin_mgr.plugins()
reply_str = '所有插件({}):\n'.format(len(plugin_list)) reply_str = '所有插件({}):\n'.format(len(plugin_list))
idx = 0 idx = 0
@@ -27,32 +30,36 @@ class PluginOperator(operator.CommandOperator):
idx += 1 idx += 1
yield entities.CommandReturn(text=reply_str) yield command_context.CommandReturn(text=reply_str)
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
class PluginGetOperator(operator.CommandOperator): class PluginGetOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址'))
else: else:
repo = context.crt_params[0] repo = context.crt_params[0]
yield entities.CommandReturn(text='正在安装插件...') yield command_context.CommandReturn(text='正在安装插件...')
try: try:
await self.ap.plugin_mgr.install_plugin(repo) await self.ap.plugin_mgr.install_plugin(repo)
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件') yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e)))
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
class PluginUpdateOperator(operator.CommandOperator): class PluginUpdateOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
@@ -60,24 +67,26 @@ class PluginUpdateOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None: if plugin_container is not None:
yield entities.CommandReturn(text='正在更新插件...') yield command_context.CommandReturn(text='正在更新插件...')
await self.ap.plugin_mgr.update_plugin(plugin_name) await self.ap.plugin_mgr.update_plugin(plugin_name)
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件') yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件')
else: else:
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件')) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件'))
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator) @operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
class PluginUpdateAllOperator(operator.CommandOperator): class PluginUpdateAllOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
try: try:
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()] plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
if plugins: if plugins:
yield entities.CommandReturn(text='正在更新插件...') yield command_context.CommandReturn(text='正在更新插件...')
updated = [] updated = []
try: try:
for plugin_name in plugins: for plugin_name in plugins:
@@ -85,20 +94,22 @@ class PluginUpdateAllOperator(operator.CommandOperator):
updated.append(plugin_name) updated.append(plugin_name)
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated))) yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
else: else:
yield entities.CommandReturn(text='没有可更新的插件') yield command_context.CommandReturn(text='没有可更新的插件')
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
class PluginDelOperator(operator.CommandOperator): class PluginDelOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
@@ -106,51 +117,55 @@ class PluginDelOperator(operator.CommandOperator):
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name) plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
if plugin_container is not None: if plugin_container is not None:
yield entities.CommandReturn(text='正在删除插件...') yield command_context.CommandReturn(text='正在删除插件...')
await self.ap.plugin_mgr.uninstall_plugin(plugin_name) await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件') yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件')
else: else:
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件')) yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件'))
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e)))
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
class PluginEnableOperator(operator.CommandOperator): class PluginEnableOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name)) yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name))
else: else:
yield entities.CommandReturn( yield command_context.CommandReturn(
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
) )
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator) @operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
class PluginDisableOperator(operator.CommandOperator): class PluginDisableOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
if len(context.crt_params) == 0: if len(context.crt_params) == 0:
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称')) yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
else: else:
plugin_name = context.crt_params[0] plugin_name = context.crt_params[0]
try: try:
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False): if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name)) yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
else: else:
yield entities.CommandReturn( yield command_context.CommandReturn(
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name)) error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
) )
except Exception as e: except Exception as e:
traceback.print_exc() traceback.print_exc()
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e))) yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))

View File

@@ -2,19 +2,22 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt') @operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
class PromptOperator(operator.CommandOperator): class PromptOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行""" """执行"""
if context.session.using_conversation is None: if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
else: else:
reply_str = '当前对话所有内容:\n\n' reply_str = '当前对话所有内容:\n\n'
for msg in context.session.using_conversation.messages: for msg in context.session.using_conversation.messages:
reply_str += f'{msg.role}: {msg.content}\n' reply_str += f'{msg.role}: {msg.content}\n'
yield entities.CommandReturn(text=reply_str) yield command_context.CommandReturn(text=reply_str)

View File

@@ -2,15 +2,18 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities, errors from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend') @operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
class ResendOperator(operator.CommandOperator): class ResendOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
# 回滚到最后一条用户message前 # 回滚到最后一条用户message前
if context.session.using_conversation is None: if context.session.using_conversation is None:
yield entities.CommandReturn(error=errors.CommandError('当前没有对话')) yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话'))
else: else:
conv_msg = context.session.using_conversation.messages conv_msg = context.session.using_conversation.messages
@@ -23,4 +26,4 @@ class ResendOperator(operator.CommandOperator):
conv_msg.pop() conv_msg.pop()
# 不重发了,提示用户已删除就行了 # 不重发了,提示用户已删除就行了
yield entities.CommandReturn(text='已删除最后一次请求记录') yield command_context.CommandReturn(text='已删除最后一次请求记录')

View File

@@ -2,13 +2,16 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset') @operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
class ResetOperator(operator.CommandOperator): class ResetOperator(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
"""执行""" """执行"""
context.session.using_conversation = None context.session.using_conversation = None
yield entities.CommandReturn(text='已重置当前会话') yield command_context.CommandReturn(text='已重置当前会话')

View File

@@ -1,11 +0,0 @@
from __future__ import annotations
import typing
from .. import operator, entities
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
class UpdateCommand(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。')

View File

@@ -2,12 +2,15 @@ from __future__ import annotations
import typing import typing
from .. import operator, entities from .. import operator
from langbot_plugin.api.entities.builtin.command import context as command_context
@operator.operator_class(name='version', help='显示版本信息', usage='!version') @operator.operator_class(name='version', help='显示版本信息', usage='!version')
class VersionCommand(operator.CommandOperator): class VersionCommand(operator.CommandOperator):
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]: async def execute(
self, context: command_context.ExecuteContext
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}' reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
try: try:
@@ -16,4 +19,4 @@ class VersionCommand(operator.CommandOperator):
except Exception: except Exception:
pass pass
yield entities.CommandReturn(text=reply_str.strip()) yield command_context.CommandReturn(text=reply_str.strip())

View File

@@ -3,7 +3,6 @@ from __future__ import annotations
import logging import logging
import asyncio import asyncio
import traceback import traceback
import sys
import os import os
from ..platform import botmgr as im_mgr from ..platform import botmgr as im_mgr
@@ -12,7 +11,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr
from ..provider.tools import toolmgr as llm_tool_mgr from ..provider.tools import toolmgr as llm_tool_mgr
from ..config import manager as config_mgr from ..config import manager as config_mgr
from ..command import cmdmgr from ..command import cmdmgr
from ..plugin import manager as plugin_mgr from ..plugin import connector as plugin_connector
from ..pipeline import pool from ..pipeline import pool
from ..pipeline import controller, pipelinemgr from ..pipeline import controller, pipelinemgr
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
@@ -75,7 +74,7 @@ class Application:
# ========================= # =========================
plugin_mgr: plugin_mgr.PluginManager = None plugin_connector: plugin_connector.PluginRuntimeConnector = None
query_pool: pool.QueryPool = None query_pool: pool.QueryPool = None
@@ -117,7 +116,7 @@ class Application:
async def run(self): async def run(self):
try: try:
await self.plugin_mgr.initialize_plugins() await self.plugin_connector.initialize_plugins()
# 后续可能会允许动态重启其他任务 # 后续可能会允许动态重启其他任务
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程 # 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
@@ -157,6 +156,9 @@ class Application:
self.logger.error(f'应用运行致命异常: {e}') self.logger.error(f'应用运行致命异常: {e}')
self.logger.debug(f'Traceback: {traceback.format_exc()}') self.logger.debug(f'Traceback: {traceback.format_exc()}')
def dispose(self):
self.plugin_connector.dispose()
async def print_web_access_info(self): async def print_web_access_info(self):
"""打印访问 webui 的提示""" """打印访问 webui 的提示"""
@@ -183,59 +185,3 @@ class Application:
""".strip() """.strip()
for line in tips.split('\n'): for line in tips.split('\n'):
self.logger.info(line) self.logger.info(line)
async def reload(
self,
scope: core_entities.LifecycleControlScope,
):
match scope:
case core_entities.LifecycleControlScope.PLATFORM.value:
self.logger.info('执行热重载 scope=' + scope)
await self.platform_mgr.shutdown()
self.platform_mgr = im_mgr.PlatformManager(self)
await self.platform_mgr.initialize()
self.task_mgr.create_task(
self.platform_mgr.run(),
name='platform-manager',
scopes=[
core_entities.LifecycleControlScope.APPLICATION,
core_entities.LifecycleControlScope.PLATFORM,
],
)
case core_entities.LifecycleControlScope.PLUGIN.value:
self.logger.info('执行热重载 scope=' + scope)
await self.plugin_mgr.destroy_plugins()
# 删除 sys.module 中所有的 plugins/* 下的模块
for mod in list(sys.modules.keys()):
if mod.startswith('plugins.'):
del sys.modules[mod]
self.plugin_mgr = plugin_mgr.PluginManager(self)
await self.plugin_mgr.initialize()
await self.plugin_mgr.initialize_plugins()
await self.plugin_mgr.load_plugins()
await self.plugin_mgr.initialize_plugins()
case core_entities.LifecycleControlScope.PROVIDER.value:
self.logger.info('执行热重载 scope=' + scope)
await self.tool_mgr.shutdown()
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
await llm_model_mgr_inst.initialize()
self.model_mgr = llm_model_mgr_inst
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
await llm_session_mgr_inst.initialize()
self.sess_mgr = llm_session_mgr_inst
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
await llm_tool_mgr_inst.initialize()
self.tool_mgr = llm_tool_mgr_inst
case _:
pass

View File

@@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop):
import signal import signal
def signal_handler(sig, frame): def signal_handler(sig, frame):
app_inst.dispose()
print('[Signal] 程序退出.') print('[Signal] 程序退出.')
# ap.shutdown()
os._exit(0) os._exit(0)
signal.signal(signal.SIGINT, signal_handler) signal.signal(signal.SIGINT, signal_handler)

View File

@@ -1,18 +1,6 @@
from __future__ import annotations from __future__ import annotations
import enum import enum
import typing
import datetime
import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
class LifecycleControlScope(enum.Enum): class LifecycleControlScope(enum.Enum):
@@ -20,157 +8,3 @@ class LifecycleControlScope(enum.Enum):
PLATFORM = 'platform' PLATFORM = 'platform'
PLUGIN = 'plugin' PLUGIN = 'plugin'
PROVIDER = 'provider' PROVIDER = 'provider'
class LauncherTypes(enum.Enum):
"""一个请求的发起者类型"""
PERSON = 'person'
"""私聊"""
GROUP = 'group'
"""群聊"""
class Query(pydantic.BaseModel):
"""一次请求的信息封装"""
query_id: int
"""请求ID添加进请求池时生成"""
launcher_type: LauncherTypes
"""会话类型platform处理阶段设置"""
launcher_id: typing.Union[int, str]
"""会话IDplatform处理阶段设置"""
sender_id: typing.Union[int, str]
"""发送者IDplatform处理阶段设置"""
message_event: platform_events.MessageEvent
"""事件platform收到的原始事件"""
message_chain: platform_message.MessageChain
"""消息链platform收到的原始消息链"""
bot_uuid: typing.Optional[str] = None
"""机器人UUID。"""
pipeline_uuid: typing.Optional[str] = None
"""流水线UUID。"""
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
"""流水线配置,由 Pipeline 在运行开始时设置。"""
adapter: msadapter.MessagePlatformAdapter
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None
"""会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
"""历史消息列表,由前置处理器阶段设置"""
prompt: typing.Optional[llm_entities.Prompt] = None
"""情景预设内容,由前置处理器阶段设置"""
user_message: typing.Optional[llm_entities.Message] = None
"""此次请求的用户消息对象,由前置处理器阶段设置"""
variables: typing.Optional[dict[str, typing.Any]] = None
"""变量由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
"""使用的对话模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
) = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
"""当前所处阶段"""
class Config:
arbitrary_types_allowed = True
# ========== 插件可调用的 API请求 API ==========
def set_variable(self, key: str, value: typing.Any):
"""设置变量"""
if self.variables is None:
self.variables = {}
self.variables[key] = value
def get_variable(self, key: str) -> typing.Any:
"""获取变量"""
if self.variables is None:
return None
return self.variables.get(key)
def get_variables(self) -> dict[str, typing.Any]:
"""获取所有变量"""
if self.variables is None:
return {}
return self.variables
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: llm_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
pipeline_uuid: str
"""流水线UUID。"""
bot_uuid: str
"""机器人UUID。"""
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Config:
arbitrary_types_allowed = True
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: typing.Union[int, str]
sender_id: typing.Optional[typing.Union[int, str]] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""
class Config:
arbitrary_types_allowed = True

View File

@@ -1,10 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio
from .. import stage, app from .. import stage, app
from ...utils import version, proxy, announce from ...utils import version, proxy, announce
from ...pipeline import pool, controller, pipelinemgr from ...pipeline import pool, controller, pipelinemgr
from ...plugin import manager as plugin_mgr from ...plugin import connector as plugin_connector
from ...command import cmdmgr from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr
@@ -59,10 +60,13 @@ class BuildAppStage(stage.BootingStage):
ap.persistence_mgr = persistence_mgr_inst ap.persistence_mgr = persistence_mgr_inst
await persistence_mgr_inst.initialize() await persistence_mgr_inst.initialize()
plugin_mgr_inst = plugin_mgr.PluginManager(ap) async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
await plugin_mgr_inst.initialize() await asyncio.sleep(3)
ap.plugin_mgr = plugin_mgr_inst await plugin_connector_inst.initialize()
await plugin_mgr_inst.load_plugins()
plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback)
await plugin_connector_inst.initialize()
ap.plugin_connector = plugin_connector_inst
cmd_mgr_inst = cmdmgr.CommandManager(ap) cmd_mgr_inst = cmdmgr.CommandManager(ap)
await cmd_mgr_inst.initialize() await cmd_mgr_inst.initialize()

View File

@@ -0,0 +1,22 @@
import sqlalchemy
from .base import Base
class BinaryStorage(Base):
"""Current for plugin use only"""
__tablename__ = 'binary_storages'
unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column(
sqlalchemy.DateTime,
nullable=False,
server_default=sqlalchemy.func.now(),
onupdate=sqlalchemy.func.now(),
)

View File

@@ -13,6 +13,8 @@ class PluginSetting(Base):
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True) enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0) priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict) config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github')
install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now()) created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
updated_at = sqlalchemy.Column( updated_at = sqlalchemy.Column(
sqlalchemy.DateTime, sqlalchemy.DateTime,

View File

@@ -44,6 +44,38 @@ class PersistenceManager:
await self.create_tables() await self.create_tables()
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if (
migration_instance.number > database_version
and migration_instance.number <= required_database_version
):
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata)
.where(metadata.Metadata.key == 'database_version')
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def create_tables(self): async def create_tables(self):
# create tables # create tables
async with self.get_db_engine().connect() as conn: async with self.get_db_engine().connect() as conn:
@@ -87,38 +119,6 @@ class PersistenceManager:
# ================================= # =================================
# run migrations
database_version = await self.execute_async(
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
)
database_version = int(database_version.fetchone()[1])
required_database_version = constants.required_database_version
if database_version < required_database_version:
migrations = migration.preregistered_db_migrations
migrations.sort(key=lambda x: x.number)
last_migration_number = database_version
for migration_cls in migrations:
migration_instance = migration_cls(self.ap)
if (
migration_instance.number > database_version
and migration_instance.number <= required_database_version
):
await migration_instance.upgrade()
await self.execute_async(
sqlalchemy.update(metadata.Metadata)
.where(metadata.Metadata.key == 'database_version')
.values({'value': str(migration_instance.number)})
)
last_migration_number = migration_instance.number
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult: async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
async with self.get_db_engine().connect() as conn: async with self.get_db_engine().connect() as conn:
result = await conn.execute(*args, **kwargs) result = await conn.execute(*args, **kwargs)
@@ -128,10 +128,13 @@ class PersistenceManager:
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine: def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
return self.db.get_engine() return self.db.get_engine()
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict: def serialize_model(
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = []
) -> dict:
return { return {
column.name: getattr(data, column.name) column.name: getattr(data, column.name)
if not isinstance(getattr(data, column.name), (datetime.datetime)) if not isinstance(getattr(data, column.name), (datetime.datetime))
else getattr(data, column.name).isoformat() else getattr(data, column.name).isoformat()
for column in model.__table__.columns for column in model.__table__.columns
if column.name not in masked_columns
} }

View File

@@ -0,0 +1,20 @@
from .. import migration
@migration.migration_class(4)
class DBMigratePluginConfig(migration.DBMigration):
"""插件配置"""
async def upgrade(self):
"""升级"""
if 'plugin' not in self.ap.instance_config.data:
self.ap.instance_config.data['plugin'] = {
'runtime_ws_url': 'ws://localhost:5400/control/ws',
}
await self.ap.instance_config.dump_config()
async def downgrade(self):
"""降级"""
pass

View File

@@ -0,0 +1,32 @@
import sqlalchemy
from .. import migration
@migration.migration_class(5)
class DBMigratePluginInstallSource(migration.DBMigration):
"""插件安装来源"""
async def upgrade(self):
"""升级"""
# 查询表结构获取所有列名(异步执行 SQL
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(plugin_settings);'))
# fetchall() 是同步方法,无需 await
columns = [row[1] for row in result.fetchall()]
# 检查并添加 install_source 列
if 'install_source' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text(
"ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'"
)
)
# 检查并添加 install_info 列
if 'install_info' not in columns:
await self.ap.persistence_mgr.execute_async(
sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'")
)
async def downgrade(self):
"""降级"""
pass

View File

@@ -1,7 +1,7 @@
from __future__ import annotations from __future__ import annotations
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('BanSessionCheckStage') @stage.stage_class('BanSessionCheckStage')
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
async def initialize(self, pipeline_config: dict): async def initialize(self, pipeline_config: dict):
pass pass
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
found = False found = False
mode = query.pipeline_config['trigger']['access-control']['mode'] mode = query.pipeline_config['trigger']['access-control']['mode']

View File

@@ -3,12 +3,11 @@ from __future__ import annotations
from ...core import app from ...core import app
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from . import filter as filter_model, entities as filter_entities from . import filter as filter_model, entities as filter_entities
from ...provider import entities as llm_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import filters from . import filters
importutil.import_modules_in_pkg(filters) importutil.import_modules_in_pkg(filters)
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
async def _pre_process( async def _pre_process(
self, self,
message: str, message: str,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""请求llm前处理消息 """请求llm前处理消息
只要有一个不通过就不放行,只放行 PASS 的消息 只要有一个不通过就不放行,只放行 PASS 的消息
@@ -67,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg': if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
if not message.strip(): if not message.strip():
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
else: else:
for filter in self.filter_chain: for filter in self.filter_chain:
if filter_entities.EnableStage.PRE in filter.enable_stages: if filter_entities.EnableStage.PRE in filter.enable_stages:
@@ -86,14 +85,14 @@ class ContentFilterStage(stage.PipelineStage):
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个 elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
message = result.replacement message = result.replacement
query.message_chain = platform_message.MessageChain(platform_message.Plain(message)) query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def _post_process( async def _post_process(
self, self,
message: str, message: str,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""请求llm后处理响应 """请求llm后处理响应
只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter 只要是 PASS 或者 MASKED 的就通过此 filter将其 replacement 设置为message进入下一个 filter
@@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
if stage_inst_name == 'PreContentFilterStage': if stage_inst_name == 'PreContentFilterStage':
contain_non_text = False contain_non_text = False
@@ -142,7 +141,7 @@ class ContentFilterStage(stage.PipelineStage):
return await self._pre_process(str(query.message_chain).strip(), query) return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage': elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况 # 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance( if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
query.resp_messages[-1].content, str query.resp_messages[-1].content, str
): ):
return await self._post_process(query.resp_messages[-1].content, query) return await self._post_process(query.resp_messages[-1].content, query)

View File

@@ -1,6 +1,6 @@
import enum import enum
import pydantic.v1 as pydantic import pydantic
class ResultLevel(enum.Enum): class ResultLevel(enum.Enum):

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import entities from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_filters: list[typing.Type[ContentFilter]] = [] preregistered_filters: list[typing.Type[ContentFilter]] = []
@@ -60,7 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
"""处理消息 """处理消息
分为前后阶段,具体取决于 enable_stages 的值。 分为前后阶段,具体取决于 enable_stages 的值。

View File

@@ -4,8 +4,7 @@ import aiohttp
from .. import entities from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}' BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token' BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
) as resp: ) as resp:
return (await resp.json())['access_token'] return (await resp.json())['access_token']
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.post( async with session.post(
BAIDU_EXAMINE_URL.format(await self._get_token()), BAIDU_EXAMINE_URL.format(await self._get_token()),

View File

@@ -3,7 +3,7 @@ import re
from .. import filter as filter_model from .. import filter as filter_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('ban-word-filter') @filter_model.filter_class('ban-word-filter')
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
async def initialize(self): async def initialize(self):
pass pass
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
found = False found = False
for word in self.ap.sensitive_meta.data['words']: for word in self.ap.sensitive_meta.data['words']:

View File

@@ -3,7 +3,7 @@ import re
from .. import entities from .. import entities
from .. import filter as filter_model from .. import filter as filter_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@filter_model.filter_class('content-ignore') @filter_model.filter_class('content-ignore')
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
entities.EnableStage.PRE, entities.EnableStage.PRE,
] ]
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult: async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']: if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']: for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
if message.startswith(rule): if message.startswith(rule):

View File

@@ -3,7 +3,10 @@ from __future__ import annotations
import asyncio import asyncio
import traceback import traceback
from ..core import app, entities from ..core import app
from ..core import entities as core_entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class Controller: class Controller:
@@ -22,19 +25,19 @@ class Controller:
"""事件处理循环""" """事件处理循环"""
try: try:
while True: while True:
selected_query: entities.Query = None selected_query: pipeline_query.Query = None
# 取请求 # 取请求
async with self.ap.query_pool: async with self.ap.query_pool:
queries: list[entities.Query] = self.ap.query_pool.queries queries: list[pipeline_query.Query] = self.ap.query_pool.queries
for query in queries: for query in queries:
session = await self.ap.sess_mgr.get_session(query) session = await self.ap.sess_mgr.get_session(query)
self.ap.logger.debug(f'Checking query {query} session {session}') self.ap.logger.debug(f'Checking query {query} session {session}')
if not session.semaphore.locked(): if not session._semaphore.locked():
selected_query = query selected_query = query
await session.semaphore.acquire() await session._semaphore.acquire()
break break
@@ -46,7 +49,7 @@ class Controller:
if selected_query: if selected_query:
async def _process_query(selected_query: entities.Query): async def _process_query(selected_query: pipeline_query.Query):
async with self.semaphore: # 总并发上限 async with self.semaphore: # 总并发上限
# find pipeline # find pipeline
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one. # Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
@@ -59,7 +62,7 @@ class Controller:
await pipeline.run(selected_query) await pipeline.run(selected_query)
async with self.ap.query_pool: async with self.ap.query_pool:
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release() (await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
# 通知其他协程,有新的请求可以处理了 # 通知其他协程,有新的请求可以处理了
self.ap.query_pool.condition.notify_all() self.ap.query_pool.condition.notify_all()
@@ -68,8 +71,8 @@ class Controller:
kind='query', kind='query',
name=f'query-{selected_query.query_id}', name=f'query-{selected_query.query_id}',
scopes=[ scopes=[
entities.LifecycleControlScope.APPLICATION, core_entities.LifecycleControlScope.APPLICATION,
entities.LifecycleControlScope.PLATFORM, core_entities.LifecycleControlScope.PLATFORM,
], ],
) )

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import enum import enum
import typing import typing
import pydantic.v1 as pydantic import pydantic
from ..platform.types import message as platform_message
from ..core import entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
class ResultType(enum.Enum): class ResultType(enum.Enum):
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
class StageProcessResult(pydantic.BaseModel): class StageProcessResult(pydantic.BaseModel):
result_type: ResultType result_type: ResultType
new_query: entities.Query new_query: pipeline_query.Query
user_notice: typing.Optional[ user_notice: typing.Optional[
typing.Union[ typing.Union[

View File

@@ -5,10 +5,9 @@ import traceback
from . import strategy from . import strategy
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import strategies from . import strategies
importutil.import_modules_in_pkg(strategies) importutil.import_modules_in_pkg(strategies)
@@ -67,7 +66,7 @@ class LongTextProcessStage(stage.PipelineStage):
await self.strategy_impl.initialize() await self.strategy_impl.initialize()
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
# 检查是否包含非 Plain 组件 # 检查是否包含非 Plain 组件
contains_non_plain = False contains_non_plain = False

View File

@@ -3,9 +3,9 @@ from __future__ import annotations
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities
from ....platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
Forward = platform_message.Forward Forward = platform_message.Forward
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
@strategy_model.strategy_class('forward') @strategy_model.strategy_class('forward')
class ForwardComponentStrategy(strategy_model.LongTextStrategy): class ForwardComponentStrategy(strategy_model.LongTextStrategy):
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
display = ForwardMessageDiaplay( display = ForwardMessageDiaplay(
title='群聊的聊天记录', title='群聊的聊天记录',
brief='[聊天记录]', brief='[聊天记录]',

View File

@@ -8,10 +8,10 @@ import re
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
import functools import functools
from ....platform.types import message as platform_message
from .. import strategy as strategy_model from .. import strategy as strategy_model
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.builtin.platform.message as platform_message
@strategy_model.strategy_class('image') @strategy_model.strategy_class('image')
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
encoding='utf-8', encoding='utf-8',
) )
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
img_path = self.text_to_image( img_path = self.text_to_image(
text_str=message, text_str=message,
save_as='temp/{}.png'.format(int(time.time())), save_as='temp/{}.png'.format(int(time.time())),
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
text_str: str, text_str: str,
save_as='temp.png', save_as='temp.png',
width=800, width=800,
query: core_entities.Query = None, query: pipeline_query.Query = None,
): ):
text_str = text_str.replace('\t', ' ') text_str = text_str.replace('\t', ' ')

View File

@@ -4,8 +4,9 @@ import typing
from ...core import app from ...core import app
from ...core import entities as core_entities
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_strategies: list[typing.Type[LongTextStrategy]] = [] preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
@@ -49,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]: async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
"""处理长文本 """处理长文本
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法 在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法

View File

@@ -1,10 +1,9 @@
from __future__ import annotations from __future__ import annotations
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from . import truncator from . import truncator
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import truncators from . import truncators
importutil.import_modules_in_pkg(truncators) importutil.import_modules_in_pkg(truncators)
@@ -29,7 +28,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
else: else:
raise ValueError(f'未知的截断器: {use_method}') raise ValueError(f'未知的截断器: {use_method}')
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
query = await self.trun.truncate(query) query = await self.trun.truncate(query)

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import typing import typing
import abc import abc
from ...core import entities as core_entities, app from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_truncators: list[typing.Type[Truncator]] = [] preregistered_truncators: list[typing.Type[Truncator]] = []
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
async def truncate(self, query: core_entities.Query) -> core_entities.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断 """截断
一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。 一般只需要操作query.messages也可以扩展操作query.prompt, query.user_message。

View File

@@ -1,14 +1,14 @@
from __future__ import annotations from __future__ import annotations
from .. import truncator from .. import truncator
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@truncator.truncator_class('round') @truncator.truncator_class('round')
class RoundTruncator(truncator.Truncator): class RoundTruncator(truncator.Truncator):
"""前文回合数阶段器""" """前文回合数阶段器"""
async def truncate(self, query: core_entities.Query) -> core_entities.Query: async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
"""截断""" """截断"""
max_round = query.pipeline_config['ai']['local-agent']['max-round'] max_round = query.pipeline_config['ai']['local-agent']['max-round']

View File

@@ -5,14 +5,18 @@ import traceback
import sqlalchemy import sqlalchemy
from ..core import app, entities from ..core import app
from . import entities as pipeline_entities from . import entities as pipeline_entities
from ..entity.persistence import pipeline as persistence_pipeline from ..entity.persistence import pipeline as persistence_pipeline
from . import stage from . import stage
from ..platform.types import message as platform_message, events as platform_events import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..plugin import events import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.events as events
from ..utils import importutil from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import ( from . import (
resprule, resprule,
bansess, bansess,
@@ -75,17 +79,17 @@ class RuntimePipeline:
self.pipeline_entity = pipeline_entity self.pipeline_entity = pipeline_entity
self.stage_containers = stage_containers self.stage_containers = stage_containers
async def run(self, query: entities.Query): async def run(self, query: pipeline_query.Query):
query.pipeline_config = self.pipeline_entity.config query.pipeline_config = self.pipeline_entity.config
await self.process_query(query) await self.process_query(query)
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult): async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
"""检查输出""" """检查输出"""
if result.user_notice: if result.user_notice:
# 处理str类型 # 处理str类型
if isinstance(result.user_notice, str): if isinstance(result.user_notice, str):
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice)) result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
elif isinstance(result.user_notice, list): elif isinstance(result.user_notice, list):
result.user_notice = platform_message.MessageChain(*result.user_notice) result.user_notice = platform_message.MessageChain(*result.user_notice)
@@ -109,7 +113,7 @@ class RuntimePipeline:
async def _execute_from_stage( async def _execute_from_stage(
self, self,
stage_index: int, stage_index: int,
query: entities.Query, query: pipeline_query.Query,
): ):
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。 """从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
@@ -136,7 +140,7 @@ class RuntimePipeline:
while i < len(self.stage_containers): while i < len(self.stage_containers):
stage_container = self.stage_containers[i] stage_container = self.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里 query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name) result = stage_container.inst.process(query, stage_container.inst_name)
@@ -169,26 +173,26 @@ class RuntimePipeline:
i += 1 i += 1
async def process_query(self, query: entities.Query): async def process_query(self, query: pipeline_query.Query):
"""处理请求""" """处理请求"""
try: try:
# ======== 触发 MessageReceived 事件 ======== # ======== 触发 MessageReceived 事件 ========
event_type = ( event_type = (
events.PersonMessageReceived events.PersonMessageReceived
if query.launcher_type == entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupMessageReceived else events.GroupMessageReceived
) )
event_ctx = await self.ap.plugin_mgr.emit_event( event_obj = event_type(
event=event_type( query=query,
launcher_type=query.launcher_type.value, launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id, launcher_id=query.launcher_id,
sender_id=query.sender_id, sender_id=query.sender_id,
message_chain=query.message_chain, message_chain=query.message_chain,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
return return
@@ -196,11 +200,12 @@ class RuntimePipeline:
await self._execute_from_stage(0, query) await self._execute_from_stage(0, query)
except Exception as e: except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown' inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}') self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}') self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
finally: finally:
self.ap.logger.debug(f'Query {query} processed') self.ap.logger.debug(f'Query {query} processed')
del self.ap.query_pool.cached_queries[query.query_id]
class PipelineManager: class PipelineManager:

View File

@@ -3,10 +3,11 @@ from __future__ import annotations
import asyncio import asyncio
import typing import typing
from ..core import entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..platform import adapter as msadapter import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..platform.types import message as platform_message import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ..platform.types import events as platform_events import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
class QueryPool: class QueryPool:
@@ -16,7 +17,10 @@ class QueryPool:
pool_lock: asyncio.Lock pool_lock: asyncio.Lock
queries: list[entities.Query] queries: list[pipeline_query.Query]
cached_queries: dict[int, pipeline_query.Query]
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
condition: asyncio.Condition condition: asyncio.Condition
@@ -24,34 +28,38 @@ class QueryPool:
self.query_id_counter = 0 self.query_id_counter = 0
self.pool_lock = asyncio.Lock() self.pool_lock = asyncio.Lock()
self.queries = [] self.queries = []
self.cached_queries = {}
self.condition = asyncio.Condition(self.pool_lock) self.condition = asyncio.Condition(self.pool_lock)
async def add_query( async def add_query(
self, self,
bot_uuid: str, bot_uuid: str,
launcher_type: entities.LauncherTypes, launcher_type: provider_session.LauncherTypes,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str], sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent, message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
pipeline_uuid: typing.Optional[str] = None, pipeline_uuid: typing.Optional[str] = None,
) -> entities.Query: ) -> pipeline_query.Query:
async with self.condition: async with self.condition:
query = entities.Query( query_id = self.query_id_counter
query = pipeline_query.Query(
bot_uuid=bot_uuid, bot_uuid=bot_uuid,
query_id=self.query_id_counter, query_id=query_id,
launcher_type=launcher_type, launcher_type=launcher_type,
launcher_id=launcher_id, launcher_id=launcher_id,
sender_id=sender_id, sender_id=sender_id,
message_event=message_event, message_event=message_event,
message_chain=message_chain, message_chain=message_chain,
variables={},
resp_messages=[], resp_messages=[],
resp_message_chain=[], resp_message_chain=[],
adapter=adapter, adapter=adapter,
pipeline_uuid=pipeline_uuid, pipeline_uuid=pipeline_uuid,
) )
self.queries.append(query) self.queries.append(query)
self.cached_queries[query_id] = query
self.query_id_counter += 1 self.query_id_counter += 1
self.condition.notify_all() self.condition.notify_all()

View File

@@ -3,10 +3,10 @@ from __future__ import annotations
import datetime import datetime
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...provider import entities as llm_entities import langbot_plugin.api.entities.events as events
from ...plugin import events import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('PreProcessor') @stage.stage_class('PreProcessor')
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""处理""" """处理"""
@@ -49,31 +49,31 @@ class PreProcessor(stage.PipelineStage):
query.bot_uuid, query.bot_uuid,
) )
conversation.use_llm_model = llm_model
# 设置query # 设置query
query.session = session query.session = session
query.prompt = conversation.prompt.copy() query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy() query.messages = conversation.messages.copy()
query.use_llm_model = llm_model query.use_llm_model_uuid = llm_model.model_entity.uuid
if selected_runner == 'local-agent': if selected_runner == 'local-agent':
query.use_funcs = ( query.use_funcs = []
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None
)
query.variables = { if llm_model.model_entity.abilities.__contains__('func_call'):
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}', 'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
'conversation_id': conversation.uuid, 'conversation_id': conversation.uuid,
'msg_create_time': ( 'msg_create_time': (
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp()) int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
), ),
} }
query.variables.update(variables)
# Check if this model supports vision, if not, remove all images # Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved # TODO this checking should be performed in runner, and in this stage, the image should be reserved
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'): if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages: for msg in query.messages:
if isinstance(msg.content, list): if isinstance(msg.content, list):
for me in msg.content: for me in msg.content:
@@ -87,39 +87,35 @@ class PreProcessor(stage.PipelineStage):
for me in query.message_chain: for me in query.message_chain:
if isinstance(me, platform_message.Plain): if isinstance(me, platform_message.Plain):
content_list.append(llm_entities.ContentElement.from_text(me.text)) content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text plain_text += me.text
elif isinstance(me, platform_message.Image): elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
'vision'
):
if me.base64 is not None: if me.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64)) content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.Quote) and qoute_msg: elif isinstance(me, platform_message.Quote) and qoute_msg:
for msg in me.origin: for msg in me.origin:
if isinstance(msg, platform_message.Plain): if isinstance(msg, platform_message.Plain):
content_list.append(llm_entities.ContentElement.from_text(msg.text)) content_list.append(provider_message.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image): elif isinstance(msg, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__( if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
'vision'
):
if msg.base64 is not None: if msg.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64)) content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
query.variables['user_message_text'] = plain_text query.variables['user_message_text'] = plain_text
query.user_message = llm_entities.Message(role='user', content=content_list) query.user_message = provider_message.Message(role='user', content=content_list)
# =========== 触发事件 PromptPreProcessing # =========== 触发事件 PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event( event = events.PromptPreProcessing(
event=events.PromptPreProcessing( session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}', default_prompt=query.prompt.messages,
default_prompt=query.prompt.messages, prompt=query.messages,
prompt=query.messages, query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
query.prompt.messages = event_ctx.event.default_prompt query.prompt.messages = event_ctx.event.default_prompt
query.messages = event_ctx.event.prompt query.messages = event_ctx.event.prompt

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
import abc import abc
from ...core import app from ...core import app
from ...core import entities as core_entities
from .. import entities from .. import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
class MessageHandler(metaclass=abc.ABCMeta): class MessageHandler(metaclass=abc.ABCMeta):
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
raise NotImplementedError raise NotImplementedError

View File

@@ -6,13 +6,15 @@ import traceback
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities
from ....provider import runner as runner_module from ....provider import runner as runner_module
from ....plugin import events
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.events as events
from ....utils import importutil from ....utils import importutil
from ....provider import runners from ....provider import runners
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
importutil.import_modules_in_pkg(runners) importutil.import_modules_in_pkg(runners)
@@ -20,7 +22,7 @@ importutil.import_modules_in_pkg(runners)
class ChatMessageHandler(handler.MessageHandler): class ChatMessageHandler(handler.MessageHandler):
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
# 调API # 调API
@@ -29,20 +31,20 @@ class ChatMessageHandler(handler.MessageHandler):
# 触发插件事件 # 触发插件事件
event_class = ( event_class = (
events.PersonNormalMessageReceived events.PersonNormalMessageReceived
if query.launcher_type == core_entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupNormalMessageReceived else events.GroupNormalMessageReceived
) )
event_ctx = await self.ap.plugin_mgr.emit_event( event = event_class(
event=event_class( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, text_message=str(query.message_chain),
text_message=str(query.message_chain), query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)

View File

@@ -4,16 +4,17 @@ import typing
from .. import handler from .. import handler
from ... import entities from ... import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.provider.message as provider_message
from ....provider import entities as llm_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....plugin import events import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
class CommandHandler(handler.MessageHandler): class CommandHandler(handler.MessageHandler):
async def handle( async def handle(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
@@ -28,23 +29,23 @@ class CommandHandler(handler.MessageHandler):
event_class = ( event_class = (
events.PersonCommandSent events.PersonCommandSent
if query.launcher_type == core_entities.LauncherTypes.PERSON if query.launcher_type == provider_session.LauncherTypes.PERSON
else events.GroupCommandSent else events.GroupCommandSent
) )
event_ctx = await self.ap.plugin_mgr.emit_event( event = event_class(
event=event_class( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, command=spt[0],
command=spt[0], params=spt[1:] if len(spt) > 1 else [],
params=spt[1:] if len(spt) > 1 else [], text_message=str(query.message_chain),
text_message=str(query.message_chain), is_admin=(privilege == 2),
is_admin=(privilege == 2), query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
mc = platform_message.MessageChain(event_ctx.event.reply) mc = platform_message.MessageChain(event_ctx.event.reply)
@@ -64,7 +65,7 @@ class CommandHandler(handler.MessageHandler):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session): async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None: if ret.error is not None:
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( provider_message.Message(
role='command', role='command',
content=str(ret.error), content=str(ret.error),
) )
@@ -74,16 +75,16 @@ class CommandHandler(handler.MessageHandler):
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None: elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement] = [] content: list[provider_message.ContentElement] = []
if ret.text is not None: if ret.text is not None:
content.append(llm_entities.ContentElement.from_text(ret.text)) content.append(provider_message.ContentElement.from_text(ret.text))
if ret.image_url is not None: if ret.image_url is not None:
content.append(llm_entities.ContentElement.from_image_url(ret.image_url)) content.append(provider_message.ContentElement.from_image_url(ret.image_url))
query.resp_messages.append( query.resp_messages.append(
llm_entities.Message( provider_message.Message(
role='command', role='command',
content=content, content=content,
) )

View File

@@ -1,10 +1,10 @@
from __future__ import annotations from __future__ import annotations
from ...core import entities as core_entities
from . import handler from . import handler
from .handlers import chat, command from .handlers import chat, command
from .. import entities from .. import entities
from .. import stage from .. import stage
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('MessageProcessor') @stage.stage_class('MessageProcessor')
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> entities.StageProcessResult: ) -> entities.StageProcessResult:
"""处理""" """处理"""

View File

@@ -2,7 +2,8 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = [] preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
@@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def require_access( async def require_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
) -> bool: ) -> bool:
@@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def release_access( async def release_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
): ):

View File

@@ -3,7 +3,7 @@ import asyncio
import time import time
import typing import typing
from .. import algo from .. import algo
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
# 固定窗口算法 # 固定窗口算法
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def require_access( async def require_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
) -> bool: ) -> bool:
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
async def release_access( async def release_access(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
launcher_type: str, launcher_type: str,
launcher_id: typing.Union[int, str], launcher_id: typing.Union[int, str],
): ):

View File

@@ -4,9 +4,10 @@ import typing
from .. import entities, stage from .. import entities, stage
from . import algo from . import algo
from ...core import entities as core_entities
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import algos from . import algos
importutil.import_modules_in_pkg(algos) importutil.import_modules_in_pkg(algos)
@@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.Union[ ) -> typing.Union[
entities.StageProcessResult, entities.StageProcessResult,

View File

@@ -4,18 +4,18 @@ import random
import asyncio import asyncio
from ...platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('SendResponseBackStage') @stage.stage_class('SendResponseBackStage')
class SendResponseBackStage(stage.PipelineStage): class SendResponseBackStage(stage.PipelineStage):
"""发送响应消息""" """发送响应消息"""
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
"""处理""" """处理"""
random_range = ( random_range = (

View File

@@ -1,6 +1,6 @@
import pydantic.v1 as pydantic import pydantic
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
class RuleJudgeResult(pydantic.BaseModel): class RuleJudgeResult(pydantic.BaseModel):

View File

@@ -4,9 +4,10 @@ from __future__ import annotations
from . import rule from . import rule
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities
from ...utils import importutil from ...utils import importutil
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
from . import rules from . import rules
importutil.import_modules_in_pkg(rules) importutil.import_modules_in_pkg(rules)
@@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
await rule_inst.initialize() await rule_inst.initialize()
self.rule_matchers.append(rule_inst) self.rule_matchers.append(rule_inst)
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult: async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
if query.launcher_type.value != 'group': # 只处理群消息 if query.launcher_type.value != 'group': # 只处理群消息
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -2,10 +2,11 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ...core import app, entities as core_entities from ...core import app
from . import entities from . import entities
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregisetered_rules: list[typing.Type[GroupRespondRule]] = [] preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
@@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
"""判断消息是否匹配规则""" """判断消息是否匹配规则"""
raise NotImplementedError raise NotImplementedError

View File

@@ -3,8 +3,8 @@ from __future__ import annotations
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('at-bot') @rule_model.rule_class('at-bot')
@@ -14,7 +14,7 @@ class AtBotRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']: if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
message_chain.remove(platform_message.At(query.adapter.bot_account_id)) message_chain.remove(platform_message.At(query.adapter.bot_account_id))

View File

@@ -1,7 +1,7 @@
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('prefix') @rule_model.rule_class('prefix')
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
prefixes = rule_dict['prefix'] prefixes = rule_dict['prefix']

View File

@@ -3,8 +3,8 @@ import random
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('random') @rule_model.rule_class('random')
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
random_rate = rule_dict['random'] random_rate = rule_dict['random']

View File

@@ -3,8 +3,8 @@ import re
from .. import rule as rule_model from .. import rule as rule_model
from .. import entities from .. import entities
from ....core import entities as core_entities import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ....platform.types import message as platform_message import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@rule_model.rule_class('regexp') @rule_model.rule_class('regexp')
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
message_text: str, message_text: str,
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
rule_dict: dict, rule_dict: dict,
query: core_entities.Query, query: pipeline_query.Query,
) -> entities.RuleJudgeResult: ) -> entities.RuleJudgeResult:
regexps = rule_dict['regexp'] regexps = rule_dict['regexp']

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import abc import abc
import typing import typing
from ..core import app, entities as core_entities from ..core import app
from . import entities from . import entities
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
preregistered_stages: dict[str, type[PipelineStage]] = {} preregistered_stages: dict[str, type[PipelineStage]] = {}
@@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.Union[ ) -> typing.Union[
entities.StageProcessResult, entities.StageProcessResult,

View File

@@ -2,12 +2,12 @@ from __future__ import annotations
import typing import typing
from ...core import entities as core_entities
from .. import entities from .. import entities
from .. import stage from .. import stage
from ...plugin import events
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
import langbot_plugin.api.entities.events as events
@stage.stage_class('ResponseWrapper') @stage.stage_class('ResponseWrapper')
@@ -25,7 +25,7 @@ class ResponseWrapper(stage.PipelineStage):
async def process( async def process(
self, self,
query: core_entities.Query, query: pipeline_query.Query,
stage_inst_name: str, stage_inst_name: str,
) -> typing.AsyncGenerator[entities.StageProcessResult, None]: ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
"""处理""" """处理"""
@@ -58,21 +58,22 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = str(result.get_content_platform_message_chain()) reply_text = str(result.get_content_platform_message_chain())
# ============= 触发插件事件 =============== # ============= 触发插件事件 ===============
event_ctx = await self.ap.plugin_mgr.emit_event( event = events.NormalMessageResponded(
event=events.NormalMessageResponded( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, session=session,
session=session, prefix='',
prefix='', response_text=reply_text,
response_text=reply_text, finish_reason='stop',
finish_reason='stop', funcs_called=[fc.function.name for fc in result.tool_calls]
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None
if result.tool_calls is not None else [],
else [], query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, result_type=entities.ResultType.INTERRUPT,
@@ -96,26 +97,26 @@ class ResponseWrapper(stage.PipelineStage):
reply_text = f'调用函数 {".".join(function_names)}...' reply_text = f'调用函数 {".".join(function_names)}...'
query.resp_message_chain.append( query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(reply_text)]) platform_message.MessageChain([platform_message.Plain(text=reply_text)])
) )
if query.pipeline_config['output']['misc']['track-function-calls']: if query.pipeline_config['output']['misc']['track-function-calls']:
event_ctx = await self.ap.plugin_mgr.emit_event( event = events.NormalMessageResponded(
event=events.NormalMessageResponded( launcher_type=query.launcher_type.value,
launcher_type=query.launcher_type.value, launcher_id=query.launcher_id,
launcher_id=query.launcher_id, sender_id=query.sender_id,
sender_id=query.sender_id, session=session,
session=session, prefix='',
prefix='', response_text=reply_text,
response_text=reply_text, finish_reason='stop',
finish_reason='stop', funcs_called=[fc.function.name for fc in result.tool_calls]
funcs_called=[fc.function.name for fc in result.tool_calls] if result.tool_calls is not None
if result.tool_calls is not None else [],
else [], query=query,
query=query,
)
) )
event_ctx = await self.ap.plugin_connector.emit_event(event)
if event_ctx.is_prevented_default(): if event_ctx.is_prevented_default():
yield entities.StageProcessResult( yield entities.StageProcessResult(
result_type=entities.ResultType.INTERRUPT, result_type=entities.ResultType.INTERRUPT,
@@ -124,12 +125,12 @@ class ResponseWrapper(stage.PipelineStage):
else: else:
if event_ctx.event.reply is not None: if event_ctx.event.reply is not None:
query.resp_message_chain.append( query.resp_message_chain.append(
platform_message.MessageChain(event_ctx.event.reply) platform_message.MessageChain(text=event_ctx.event.reply)
) )
else: else:
query.resp_message_chain.append( query.resp_message_chain.append(
platform_message.MessageChain([platform_message.Plain(reply_text)]) platform_message.MessageChain([platform_message.Plain(text=reply_text)])
) )
yield entities.StageProcessResult( yield entities.StageProcessResult(

View File

@@ -1,160 +0,0 @@
from __future__ import annotations
# MessageSource的适配器
import typing
import abc
from ..core import app
from .types import message as platform_message
from .types import events as platform_events
from .logger import EventLogger
class MessagePlatformAdapter(metaclass=abc.ABCMeta):
"""消息平台适配器基类"""
name: str
bot_account_id: int
"""机器人账号ID需要在初始化时设置"""
config: dict
ap: app.Application
logger: EventLogger
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
"""初始化适配器
Args:
config (dict): 对应的配置
ap (app.Application): 应用上下文
"""
self.config = config
self.ap = ap
self.logger = logger
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息
Args:
target_type (str): 目标类型,`person`或`group`
target_id (str): 目标ID
message (platform.types.MessageChain): 消息链
"""
raise NotImplementedError
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""回复消息
Args:
message_source (platform.types.MessageEvent): 消息源事件
message (platform.types.MessageChain): 消息链
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
"""
raise NotImplementedError
async def is_muted(self, group_id: int) -> bool:
"""获取账号是否在指定群被禁言"""
raise NotImplementedError
def register_listener(
self,
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
):
"""注册事件监听器
Args:
event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
"""
raise NotImplementedError
def unregister_listener(
self,
event_type: typing.Type[platform_message.Event],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
):
"""注销事件监听器
Args:
event_type (typing.Type[platform.types.Event]): 事件类型
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
"""
raise NotImplementedError
async def run_async(self):
"""异步运行"""
raise NotImplementedError
async def kill(self) -> bool:
"""关闭适配器
Returns:
bool: 是否成功关闭热重载时若此函数返回False则不会重载MessageSource底层
"""
raise NotImplementedError
class MessageConverter:
"""消息链转换器基类"""
@staticmethod
def yiri2target(message_chain: platform_message.MessageChain):
"""将源平台消息链转换为目标平台消息链
Args:
message_chain (platform.types.MessageChain): 源平台消息链
Returns:
typing.Any: 目标平台消息链
"""
raise NotImplementedError
@staticmethod
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
"""将目标平台消息链转换为源平台消息链
Args:
message_chain (typing.Any): 目标平台消息链
Returns:
platform.types.MessageChain: 源平台消息链
"""
raise NotImplementedError
class EventConverter:
"""事件转换器基类"""
@staticmethod
def yiri2target(event: typing.Type[platform_message.Event]):
"""将源平台事件转换为目标平台事件
Args:
event (typing.Type[platform.types.Event]): 源平台事件
Returns:
typing.Any: 目标平台事件
"""
raise NotImplementedError
@staticmethod
def target2yiri(event: typing.Any) -> platform_message.Event:
"""将目标平台事件的调用参数转换为源平台的事件参数对象
Args:
event (typing.Any): 目标平台事件
Returns:
typing.Type[platform.types.Event]: 源平台事件
"""
raise NotImplementedError

View File

@@ -1,14 +0,0 @@
apiVersion: v1
kind: ComponentTemplate
metadata:
name: MessagePlatformAdapter
label:
en_US: Message Platform Adapter
zh_Hans: 消息平台适配器模板类
spec:
type:
- python
execution:
python:
path: ./adapter.py
attr: MessagePlatformAdapter

View File

@@ -1,15 +1,10 @@
from __future__ import annotations from __future__ import annotations
import sys
import asyncio import asyncio
import traceback import traceback
import sqlalchemy import sqlalchemy
# FriendMessage, Image, MessageChain, Plain
from . import adapter as msadapter
from ..core import app, entities as core_entities, taskmgr from ..core import app, entities as core_entities, taskmgr
from .types import events as platform_events, message as platform_message
from ..discover import engine from ..discover import engine
@@ -19,10 +14,10 @@ from ..entity.errors import platform as platform_errors
from .logger import EventLogger from .logger import EventLogger
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题 import langbot_plugin.api.entities.builtin.provider.session as provider_session
from . import types as mirai import langbot_plugin.api.entities.builtin.platform.events as platform_events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
sys.modules['mirai'] = mirai import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
class RuntimeBot: class RuntimeBot:
@@ -34,7 +29,7 @@ class RuntimeBot:
enable: bool enable: bool
adapter: msadapter.MessagePlatformAdapter adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
task_wrapper: taskmgr.TaskWrapper task_wrapper: taskmgr.TaskWrapper
@@ -46,7 +41,7 @@ class RuntimeBot:
self, self,
ap: app.Application, ap: app.Application,
bot_entity: persistence_bot.Bot, bot_entity: persistence_bot.Bot,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
logger: EventLogger, logger: EventLogger,
): ):
self.ap = ap self.ap = ap
@@ -59,7 +54,7 @@ class RuntimeBot:
async def initialize(self): async def initialize(self):
async def on_friend_message( async def on_friend_message(
event: platform_events.FriendMessage, event: platform_events.FriendMessage,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
): ):
image_components = [ image_components = [
component for component in event.message_chain if isinstance(component, platform_message.Image) component for component in event.message_chain if isinstance(component, platform_message.Image)
@@ -73,7 +68,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid, bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.PERSON, launcher_type=provider_session.LauncherTypes.PERSON,
launcher_id=event.sender.id, launcher_id=event.sender.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
@@ -84,7 +79,7 @@ class RuntimeBot:
async def on_group_message( async def on_group_message(
event: platform_events.GroupMessage, event: platform_events.GroupMessage,
adapter: msadapter.MessagePlatformAdapter, adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
): ):
image_components = [ image_components = [
component for component in event.message_chain if isinstance(component, platform_message.Image) component for component in event.message_chain if isinstance(component, platform_message.Image)
@@ -98,7 +93,7 @@ class RuntimeBot:
await self.ap.query_pool.add_query( await self.ap.query_pool.add_query(
bot_uuid=self.bot_entity.uuid, bot_uuid=self.bot_entity.uuid,
launcher_type=core_entities.LauncherTypes.GROUP, launcher_type=provider_session.LauncherTypes.GROUP,
launcher_id=event.group.id, launcher_id=event.group.id,
sender_id=event.sender.id, sender_id=event.sender.id,
message_event=event, message_event=event,
@@ -151,7 +146,7 @@ class PlatformManager:
adapter_components: list[engine.Component] adapter_components: list[engine.Component]
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]]
def __init__(self, ap: app.Application = None): def __init__(self, ap: app.Application = None):
self.ap = ap self.ap = ap
@@ -161,7 +156,7 @@ class PlatformManager:
async def initialize(self): async def initialize(self):
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter') self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {} adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] = {}
for component in self.adapter_components: for component in self.adapter_components:
adapter_dict[component.metadata.name] = component.get_python_component_class() adapter_dict[component.metadata.name] = component.get_python_component_class()
self.adapter_dict = adapter_dict self.adapter_dict = adapter_dict
@@ -172,9 +167,10 @@ class PlatformManager:
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap) webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
webchat_adapter_inst = webchat_adapter_class( webchat_adapter_inst = webchat_adapter_class(
{}, {},
self.ap,
webchat_logger, webchat_logger,
ap=self.ap,
) )
webchat_adapter_inst.ap = self.ap
self.webchat_proxy_bot = RuntimeBot( self.webchat_proxy_bot = RuntimeBot(
ap=self.ap, ap=self.ap,
@@ -193,7 +189,7 @@ class PlatformManager:
await self.load_bots_from_db() await self.load_bots_from_db()
def get_running_adapters(self) -> list[msadapter.MessagePlatformAdapter]: def get_running_adapters(self) -> list[abstract_platform_adapter.AbstractMessagePlatformAdapter]:
return [bot.adapter for bot in self.bots if bot.enable] return [bot.adapter for bot in self.bots if bot.enable]
async def load_bots_from_db(self): async def load_bots_from_db(self):
@@ -231,7 +227,6 @@ class PlatformManager:
adapter_inst = self.adapter_dict[bot_entity.adapter]( adapter_inst = self.adapter_dict[bot_entity.adapter](
bot_entity.adapter_config, bot_entity.adapter_config,
self.ap,
logger, logger,
) )
@@ -274,43 +269,6 @@ class PlatformManager:
return component return component
return None return None
async def write_back_config(
self,
adapter_name: str,
adapter_inst: msadapter.MessagePlatformAdapter,
config: dict,
):
# index = -2
# for i, adapter in enumerate(self.adapters):
# if adapter == adapter_inst:
# index = i
# break
# if index == -2:
# raise Exception('平台适配器未找到')
# # 只修改启用的适配器
# real_index = -1
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
# if adapter['enable']:
# index -= 1
# if index == -1:
# real_index = i
# break
# new_cfg = {
# 'adapter': adapter_name,
# 'enable': True,
# **config
# }
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
# await self.ap.platform_cfg.dump_config()
# TODO implement this
pass
async def run(self): async def run(self):
# This method will only be called when the application launching # This method will only be called when the application launching
await self.webchat_proxy_bot.run() await self.webchat_proxy_bot.run()

View File

@@ -9,7 +9,8 @@ import traceback
import uuid import uuid
from ..core import app from ..core import app
from .types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_event_logger
class EventLogLevel(enum.Enum): class EventLogLevel(enum.Enum):
@@ -55,7 +56,7 @@ MAX_LOG_COUNT = 200
DELETE_COUNT_PER_TIME = 50 DELETE_COUNT_PER_TIME = 50
class EventLogger: class EventLogger(abstract_platform_event_logger.AbstractEventLogger):
"""used for logging bot events""" """used for logging bot events"""
ap: app.Application ap: app.Application

View File

@@ -5,18 +5,17 @@ import traceback
import datetime import datetime
import aiocqhttp import aiocqhttp
import pydantic
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import entities as platform_entities
from ...utils import image from ...utils import image
from ..logger import EventLogger import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
@@ -71,15 +70,13 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.face_type=='dice': elif msg.face_type=='dice':
msg_list.append(aiocqhttp.MessageSegment.dice()) msg_list.append(aiocqhttp.MessageSegment.dice())
else: else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
return msg_list, msg_id, msg_time return msg_list, msg_id, msg_time
@staticmethod @staticmethod
async def target2yiri(message: str, message_id: int = -1,bot=None): async def target2yiri(message: str, message_id: int = -1, bot=None):
print(message)
message = aiocqhttp.Message(message) message = aiocqhttp.Message(message)
def get_face_name(face_id): def get_face_name(face_id):
@@ -119,30 +116,28 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
return face_code_dict.get(face_id,'') return face_code_dict.get(face_id,'')
async def process_message_data(msg_data, reply_list): async def process_message_data(msg_data, reply_list):
if msg_data["type"] == "image": if msg_data['type'] == 'image':
image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url']) image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url'])
reply_list.append( reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
elif msg_data["type"] == "text": elif msg_data['type'] == 'text':
reply_list.append(platform_message.Plain(text=msg_data["data"]["text"])) reply_list.append(platform_message.Plain(text=msg_data['data']['text']))
elif msg_data["type"] == "forward": # 这里来应该传入转发消息组暂时传入qoute elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组暂时传入qoute
for forward_msg_datas in msg_data["data"]["content"]: for forward_msg_datas in msg_data['data']['content']:
for forward_msg_data in forward_msg_datas["message"]: for forward_msg_data in forward_msg_datas['message']:
await process_message_data(forward_msg_data, reply_list) await process_message_data(forward_msg_data, reply_list)
elif msg_data["type"] == "at": elif msg_data['type'] == 'at':
if msg_data["data"]['qq'] == 'all': if msg_data['data']['qq'] == 'all':
reply_list.append(platform_message.AtAll()) reply_list.append(platform_message.AtAll())
else: else:
reply_list.append( reply_list.append(
platform_message.At( platform_message.At(
target=msg_data["data"]['qq'], target=msg_data['data']['qq'],
) )
) )
yiri_msg_list = [] yiri_msg_list = []
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
@@ -178,14 +173,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
# await process_message_data(msg_data, yiri_msg_list) # await process_message_data(msg_data, yiri_msg_list)
pass pass
elif msg.type == 'reply': # 此处处理引用消息传入Qoute elif msg.type == 'reply': # 此处处理引用消息传入Qoute
msg_datas = await bot.get_msg(message_id=msg.data["id"]) msg_datas = await bot.get_msg(message_id=msg.data['id'])
for msg_data in msg_datas["message"]: for msg_data in msg_datas['message']:
await process_message_data(msg_data, reply_list) await process_message_data(msg_data, reply_list)
reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list) reply_msg = platform_message.Quote(
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
)
yiri_msg_list.append(reply_msg) yiri_msg_list.append(reply_msg)
elif msg.type == 'file': elif msg.type == 'file':
@@ -194,6 +190,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
file_data = await bot.get_file(file_id=file_id) file_data = await bot.get_file(file_id=file_id)
file_name = file_data.get('file_name') file_name = file_data.get('file_name')
file_path = file_data.get('file') file_path = file_data.get('file')
_ = file_path
file_url = file_data.get('file_url') file_url = file_data.get('file_url')
file_size = file_data.get('file_size') file_size = file_data.get('file_size')
yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size)) yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
@@ -210,32 +207,19 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
face_id = msg.data['result'] face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子')) yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子'))
chain = platform_message.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
return event.source_platform_object return event.source_platform_object
@staticmethod @staticmethod
async def target2yiri(event: aiocqhttp.Event,bot=None): async def target2yiri(event: aiocqhttp.Event, bot=None):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot) yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
if event.message_type == 'group': if event.message_type == 'group':
@@ -279,23 +263,19 @@ class AiocqhttpEventConverter(adapter.EventConverter):
) )
class AiocqhttpAdapter(adapter.MessagePlatformAdapter): class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: aiocqhttp.CQHttp bot: aiocqhttp.CQHttp = pydantic.Field(exclude=True, default_factory=aiocqhttp.CQHttp)
bot_account_id: int
message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter() message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter()
event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter() event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter()
config: dict
ap: app.Application
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = [] on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
self.config = config super().__init__(
self.logger = logger config=config,
logger=logger,
)
async def shutdown_trigger_placeholder(): async def shutdown_trigger_placeholder():
while True: while True:
@@ -303,7 +283,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
self.config['shutdown_trigger'] = shutdown_trigger_placeholder self.config['shutdown_trigger'] = shutdown_trigger_placeholder
self.ap = ap
self.on_websocket_connection_event_cache = [] self.on_websocket_connection_event_cache = []
if 'access-token' in config: if 'access-token' in config:
@@ -316,7 +295,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if target_type == 'group': if target_type == 'group':
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
elif target_type == 'person': elif target_type == 'person':
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg) await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
@@ -340,12 +318,14 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: aiocqhttp.Event): async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id self.bot_account_id = event.self_id
try: try:
return await callback(await self.event_converter.target2yiri(event,self.bot), self) return await callback(await self.event_converter.target2yiri(event, self.bot), self)
except Exception: except Exception:
await self.logger.error(f'Error in on_message: {traceback.format_exc()}') await self.logger.error(f'Error in on_message: {traceback.format_exc()}')
traceback.print_exc() traceback.print_exc()
@@ -371,7 +351,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -1,18 +1,16 @@
import traceback import traceback
import typing import typing
from libs.dingtalk_api.dingtalkevent import DingTalkEvent from libs.dingtalk_api.dingtalkevent import DingTalkEvent
from pkg.platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from pkg.platform.adapter import MessagePlatformAdapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from .. import adapter import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...core import app import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import events as platform_events
from ..types import entities as platform_entities
from libs.dingtalk_api.api import DingTalkClient from libs.dingtalk_api.api import DingTalkClient
import datetime import datetime
from ..logger import EventLogger from ..logger import EventLogger
class DingTalkMessageConverter(adapter.MessageConverter): class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target(message_chain: platform_message.MessageChain): async def yiri2target(message_chain: platform_message.MessageChain):
content = '' content = ''
@@ -48,7 +46,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
return chain return chain
class DingTalkEventConverter(adapter.EventConverter): class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent): async def yiri2target(event: platform_events.MessageEvent):
return event.source_platform_object return event.source_platform_object
@@ -92,17 +90,15 @@ class DingTalkEventConverter(adapter.EventConverter):
) )
class DingTalkAdapter(adapter.MessagePlatformAdapter): class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: DingTalkClient bot: DingTalkClient
ap: app.Application
bot_account_id: str bot_account_id: str
message_converter: DingTalkMessageConverter = DingTalkMessageConverter() message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
event_converter: DingTalkEventConverter = DingTalkEventConverter() event_converter: DingTalkEventConverter = DingTalkEventConverter()
config: dict config: dict
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: EventLogger):
self.config = config self.config = config
self.ap = ap
self.logger = logger self.logger = logger
required_keys = [ required_keys = [
'client_id', 'client_id',
@@ -140,7 +136,9 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
async def on_message(event: DingTalkEvent): async def on_message(event: DingTalkEvent):
try: try:
@@ -174,6 +172,8 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
async def unregister_listener( async def unregister_listener(
self, self,
event_type: type, event_type: type,
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
return super().unregister_listener(event_type, callback) return super().unregister_listener(event_type, callback)

View File

@@ -8,19 +8,18 @@ import base64
import uuid import uuid
import os import os
import datetime import datetime
import io
import aiohttp import aiohttp
import pydantic
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import entities as platform_entities import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ..logger import EventLogger
class DiscordMessageConverter(adapter.MessageConverter): class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
@@ -36,88 +35,28 @@ class DiscordMessageConverter(adapter.MessageConverter):
for ele in message_chain: for ele in message_chain:
if isinstance(ele, platform_message.Image): if isinstance(ele, platform_message.Image):
image_bytes = None image_bytes = None
filename = f'{uuid.uuid4()}.png' # 默认文件名
if ele.base64: if ele.base64:
# 处理base64编码的图片 image_bytes = base64.b64decode(ele.base64)
if ele.base64.startswith('data:'):
# 从data URL中提取文件类型
data_header = ele.base64.split(',')[0]
if 'jpeg' in data_header or 'jpg' in data_header:
filename = f'{uuid.uuid4()}.jpg'
elif 'gif' in data_header:
filename = f'{uuid.uuid4()}.gif'
elif 'webp' in data_header:
filename = f'{uuid.uuid4()}.webp'
# 去掉data:image/xxx;base64,前缀
base64_data = ele.base64.split(',')[1]
else:
base64_data = ele.base64
image_bytes = base64.b64decode(base64_data)
elif ele.url: elif ele.url:
# 从URL下载图片
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(ele.url) as response: async with session.get(ele.url) as response:
image_bytes = await response.read() image_bytes = await response.read()
# 从URL或Content-Type推断文件类型
content_type = response.headers.get('Content-Type', '')
if 'jpeg' in content_type or 'jpg' in content_type:
filename = f'{uuid.uuid4()}.jpg'
elif 'gif' in content_type:
filename = f'{uuid.uuid4()}.gif'
elif 'webp' in content_type:
filename = f'{uuid.uuid4()}.webp'
elif ele.url.lower().endswith(('.jpg', '.jpeg')):
filename = f'{uuid.uuid4()}.jpg'
elif ele.url.lower().endswith('.gif'):
filename = f'{uuid.uuid4()}.gif'
elif ele.url.lower().endswith('.webp'):
filename = f'{uuid.uuid4()}.webp'
elif ele.path: elif ele.path:
# 从文件路径读取图片 with open(ele.path, 'rb') as f:
# 确保路径没有空字节 image_bytes = f.read()
clean_path = ele.path.replace('\x00', '')
clean_path = os.path.abspath(clean_path)
if not os.path.exists(clean_path):
continue # 跳过不存在的文件
try:
with open(clean_path, 'rb') as f:
image_bytes = f.read()
# 从文件路径获取文件名,保持原始扩展名
original_filename = os.path.basename(clean_path)
if original_filename and '.' in original_filename:
# 保持原始文件名的扩展名
ext = original_filename.split('.')[-1].lower()
filename = f'{uuid.uuid4()}.{ext}'
else:
# 如果没有扩展名,尝试从文件内容检测
if image_bytes.startswith(b'\xff\xd8\xff'):
filename = f'{uuid.uuid4()}.jpg'
elif image_bytes.startswith(b'GIF'):
filename = f'{uuid.uuid4()}.gif'
elif image_bytes.startswith(b'RIFF') and b'WEBP' in image_bytes[:20]:
filename = f'{uuid.uuid4()}.webp'
# 默认保持PNG
except Exception as e:
print(f"Error reading image file {clean_path}: {e}")
continue # 跳过读取失败的文件
if image_bytes: image_files.append(discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png'))
# 使用BytesIO创建文件对象避免路径问题
import io
image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
elif isinstance(ele, platform_message.Plain): elif isinstance(ele, platform_message.Plain):
text_string += ele.text text_string += ele.text
elif isinstance(ele, platform_message.Forward): elif isinstance(ele, platform_message.Forward):
for node in ele.node_list: for node in ele.node_list:
( (
node_text, text_string,
node_images, image_files,
) = await DiscordMessageConverter.yiri2target(node.message_chain) ) = await DiscordMessageConverter.yiri2target(node.message_chain)
text_string += node_text text_string += text_string
image_files.extend(node_images) image_files.extend(image_files)
return text_string, image_files return text_string, image_files
@@ -173,7 +112,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(element_list) return platform_message.MessageChain(element_list)
class DiscordEventConverter(adapter.EventConverter): class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.Event) -> discord.Message: async def yiri2target(event: platform_events.Event) -> discord.Message:
pass pass
@@ -215,29 +154,21 @@ class DiscordEventConverter(adapter.EventConverter):
) )
class DiscordAdapter(adapter.MessagePlatformAdapter): class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: discord.Client bot: discord.Client = pydantic.Field(exclude=True)
bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识
config: dict
ap: app.Application
message_converter: DiscordMessageConverter = DiscordMessageConverter() message_converter: DiscordMessageConverter = DiscordMessageConverter()
event_converter: DiscordEventConverter = DiscordEventConverter() event_converter: DiscordEventConverter = DiscordEventConverter()
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
self.config = config bot_account_id = config['client_id']
self.ap = ap
self.logger = logger
self.bot_account_id = self.config['client_id'] listeners = {}
adapter_self = self adapter_self = self
@@ -257,30 +188,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
if os.getenv('http_proxy'): if os.getenv('http_proxy'):
args['proxy'] = os.getenv('http_proxy') args['proxy'] = os.getenv('http_proxy')
self.bot = MyClient(intents=intents, **args) bot = MyClient(intents=intents, **args)
super().__init__(
config=config,
logger=logger,
bot_account_id=bot_account_id,
listeners=listeners,
bot=bot,
)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
msg_to_send, image_files = await self.message_converter.yiri2target(message) pass
try:
# 获取频道对象
channel = self.bot.get_channel(int(target_id))
if channel is None:
# 如果本地缓存中没有尝试从API获取
channel = await self.bot.fetch_channel(int(target_id))
args = {
'content': msg_to_send,
}
if len(image_files) > 0:
args['files'] = image_files
await channel.send(**args)
except Exception as e:
await self.logger.error(f"Discord send_message failed: {e}")
raise e
async def reply_message( async def reply_message(
self, self,
@@ -312,14 +231,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners.pop(event_type) self.listeners.pop(event_type)

View File

@@ -17,13 +17,13 @@ import aiohttp
import lark_oapi.ws.exception import lark_oapi.ws.exception
import quart import quart
from lark_oapi.api.im.v1 import * from lark_oapi.api.im.v1 import *
import pydantic
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ..types import entities as platform_entities import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
from ..logger import EventLogger
class AESCipher(object): class AESCipher(object):
@@ -52,7 +52,7 @@ class AESCipher(object):
return self.decrypt(enc).decode('utf8') return self.decrypt(enc).decode('utf8')
class LarkMessageConverter(adapter.MessageConverter): class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
@@ -276,7 +276,7 @@ class LarkMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(lb_msg_list) return platform_message.MessageChain(lb_msg_list)
class LarkEventConverter(adapter.EventConverter): class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
event: platform_events.MessageEvent, event: platform_events.MessageEvent,
@@ -320,39 +320,28 @@ class LarkEventConverter(adapter.EventConverter):
) )
class LarkAdapter(adapter.MessagePlatformAdapter): class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: lark_oapi.ws.Client bot: lark_oapi.ws.Client = pydantic.Field(exclude=True)
api_client: lark_oapi.Client api_client: lark_oapi.Client = pydantic.Field(exclude=True)
bot_account_id: str # 用于在流水线中识别at是否是本bot直接以bot_name作为标识
lark_tenant_key: str # 飞书企业key
message_converter: LarkMessageConverter = LarkMessageConverter() message_converter: LarkMessageConverter = LarkMessageConverter()
event_converter: LarkEventConverter = LarkEventConverter() event_converter: LarkEventConverter = LarkEventConverter()
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] ]
config: dict quart_app: quart.Quart = pydantic.Field(exclude=True)
quart_app: quart.Quart
ap: app.Application
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
self.config = config quart_app = quart.Quart(__name__)
self.ap = ap
self.logger = logger
self.quart_app = quart.Quart(__name__)
self.listeners = {}
@self.quart_app.route('/lark/callback', methods=['POST']) @quart_app.route('/lark/callback', methods=['POST'])
async def lark_callback(): async def lark_callback():
try: try:
data = await quart.request.json data = await quart.request.json
self.ap.logger.debug(f'Lark callback event: {data}')
if 'encrypt' in data: if 'encrypt' in data:
cipher = AESCipher(self.config['encrypt-key']) cipher = AESCipher(self.config['encrypt-key'])
data = cipher.decrypt_string(data['encrypt']) data = cipher.decrypt_string(data['encrypt'])
@@ -378,15 +367,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
if 'im.message.receive_v1' == type: if 'im.message.receive_v1' == type:
try: try:
event = await self.event_converter.target2yiri(p2v1, self.api_client) event = await self.event_converter.target2yiri(p2v1, self.api_client)
except Exception as e: except Exception:
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
if event.__class__ in self.listeners: if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self) await self.listeners[event.__class__](event, self)
return {'code': 200, 'message': 'ok'} return {'code': 200, 'message': 'ok'}
except Exception as e: except Exception:
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}") await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
return {'code': 500, 'message': 'error'} return {'code': 500, 'message': 'error'}
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
@@ -401,10 +390,20 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build() lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
) )
self.bot_account_id = config['bot_name'] bot_account_id = config['bot_name']
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler) bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build() api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
super().__init__(
config=config,
logger=logger,
listeners={},
quart_app=quart_app,
bot=bot,
api_client=api_client,
bot_account_id=bot_account_id,
)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass pass
@@ -453,14 +452,18 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners.pop(event_type) self.listeners.pop(event_type)

View File

Before

Width:  |  Height:  |  Size: 25 KiB

After

Width:  |  Height:  |  Size: 25 KiB

View File

@@ -11,19 +11,19 @@ import threading
import quart import quart
import aiohttp import aiohttp
from .. import adapter import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...core import app from ....core import app
from ..types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ..types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...utils import image from ....utils import image
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Optional, Tuple from typing import Optional, Tuple
from functools import partial from functools import partial
from ..logger import EventLogger from ...logger import EventLogger
class GewechatMessageConverter(adapter.MessageConverter): class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
def __init__(self, config: dict): def __init__(self, config: dict):
self.config = config self.config = config
@@ -398,7 +398,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
return from_user_name.endswith('@chatroom') return from_user_name.endswith('@chatroom')
class GewechatEventConverter(adapter.EventConverter): class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
def __init__(self, config: dict): def __init__(self, config: dict):
self.config = config self.config = config
self.message_converter = GewechatMessageConverter(config) self.message_converter = GewechatMessageConverter(config)
@@ -458,7 +458,7 @@ class GewechatEventConverter(adapter.EventConverter):
) )
class GeWeChatAdapter(adapter.MessagePlatformAdapter): class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
name: str = 'gewechat' # 定义适配器名称 name: str = 'gewechat' # 定义适配器名称
bot: gewechat_client.GewechatClient bot: gewechat_client.GewechatClient
@@ -475,7 +475,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
listeners: typing.Dict[ listeners: typing.Dict[
typing.Type[platform_events.Event], typing.Type[platform_events.Event],
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
] = {} ] = {}
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
@@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
async def gewechat_callback(): async def gewechat_callback():
data = await quart.request.json data = await quart.request.json
# print(json.dumps(data, indent=4, ensure_ascii=False)) # print(json.dumps(data, indent=4, ensure_ascii=False))
self.ap.logger.debug(f'Gewechat callback event: {data}') await self.logger.debug(f'Gewechat callback event: {data}')
if 'data' in data: if 'data' in data:
data['Data'] = data['data'] data['Data'] = data['data']
@@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
if handler := handler_map.get(msg['type']): if handler := handler_map.get(msg['type']):
handler(msg) handler(msg)
else: else:
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}') await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue continue
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
@@ -625,14 +625,18 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
pass pass
@@ -656,9 +660,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
self.config['app_id'] = app_id self.config['app_id'] = app_id
self.ap.logger.info(f'Gewechat 登录成功app_id: {app_id}') print(f'Gewechat 登录成功app_id: {app_id}')
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
# 获取 nickname # 获取 nickname
profile = self.bot.get_profile(self.config['app_id']) profile = self.bot.get_profile(self.config['app_id'])

View File

Before

Width:  |  Height:  |  Size: 274 KiB

After

Width:  |  Height:  |  Size: 274 KiB

View File

@@ -9,15 +9,15 @@ import traceback
import nakuru import nakuru
import nakuru.entities.components as nkc import nakuru.entities.components as nkc
from .. import adapter as adapter_model from ....pipeline.longtext.strategies import forward
from ...pipeline.longtext.strategies import forward import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...platform.types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...platform.types import events as platform_events import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ..logger import EventLogger from ...logger import EventLogger
class NakuruProjectMessageConverter(adapter_model.MessageConverter): class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
"""消息转换器""" """消息转换器"""
@staticmethod @staticmethod
@@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
content=content_list, content=content_list,
) )
nakuru_forward_node_list.append(nakuru_forward_node) nakuru_forward_node_list.append(nakuru_forward_node)
except Exception as e: except Exception:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
nakuru_msg_list.append(nakuru_forward_node_list) nakuru_msg_list.append(nakuru_forward_node_list)
@@ -108,7 +109,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
return chain return chain
class NakuruProjectEventConverter(adapter_model.EventConverter): class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConverter):
"""事件转换器""" """事件转换器"""
@staticmethod @staticmethod
@@ -163,7 +164,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
raise Exception('未支持转换的事件类型: ' + str(event)) raise Exception('未支持转换的事件类型: ' + str(event))
class NakuruAdapter(adapter_model.MessagePlatformAdapter): class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""nakuru-project适配器""" """nakuru-project适配器"""
bot: nakuru.CQHTTP bot: nakuru.CQHTTP
@@ -255,13 +256,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
try: try:
source_cls = NakuruProjectEventConverter.yiri2target(event_type) source_cls = NakuruProjectEventConverter.yiri2target(event_type)
# 包装函数 # 包装函数
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
await callback(self.event_converter.target2yiri(source), self) await callback(self.event_converter.target2yiri(source), self)
# 将包装函数和原函数的对应关系存入列表 # 将包装函数和原函数的对应关系存入列表
@@ -276,13 +279,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
# 注册监听器 # 注册监听器
self.bot.receiver(source_cls.__name__)(listener_wrapper) self.bot.receiver(source_cls.__name__)(listener_wrapper)
except Exception as e: except Exception as e:
self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}") self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}')
raise e raise e
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__ nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
@@ -321,7 +326,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
except Exception: except Exception:
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确') raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
await self.bot._run() await self.bot._run()
self.ap.logger.info('运行 Nakuru 适配器')
while True: while True:
await asyncio.sleep(1) await asyncio.sleep(1)

View File

@@ -10,14 +10,14 @@ import botpy
import botpy.message as botpy_message import botpy.message as botpy_message
import botpy.types.message as botpy_message_type import botpy.types.message as botpy_message_type
from .. import adapter as adapter_model import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
from ...pipeline.longtext.strategies import forward from ....pipeline.longtext.strategies import forward
from ...core import app from ....core import app
from ...config import manager as cfg_mgr from ....config import manager as cfg_mgr
from ...platform.types import entities as platform_entities import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
from ...platform.types import events as platform_events import langbot_plugin.api.entities.builtin.platform.events as platform_events
from ...platform.types import message as platform_message import langbot_plugin.api.entities.builtin.platform.message as platform_message
from ..logger import EventLogger from ...logger import EventLogger
class OfficialGroupMessage(platform_events.GroupMessage): class OfficialGroupMessage(platform_events.GroupMessage):
@@ -133,7 +133,7 @@ class OpenIDMapping(typing.Generic[K, V]):
return value return value
class OfficialMessageConverter(adapter_model.MessageConverter): class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
"""QQ 官方消息转换器""" """QQ 官方消息转换器"""
@staticmethod @staticmethod
@@ -237,7 +237,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
return chain return chain
class OfficialEventConverter(adapter_model.EventConverter): class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
"""事件转换器""" """事件转换器"""
def __init__(self): def __init__(self):
@@ -333,7 +333,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
) )
class OfficialAdapter(adapter_model.MessagePlatformAdapter): class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
"""QQ 官方消息适配器""" """QQ 官方消息适配器"""
bot: botpy.Client = None bot: botpy.Client = None
@@ -484,7 +484,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
try: try:
@@ -501,13 +503,15 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
for event_handler in event_handler_mapping[event_type]: for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper) setattr(self.bot, event_handler, wrapper)
except Exception as e: except Exception as e:
self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}") self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}')
raise e raise e
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None], callback: typing.Callable[
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
],
): ):
delattr(self.bot, event_handler_mapping[event_type]) delattr(self.bot, event_handler_mapping[event_type])
@@ -519,7 +523,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
self.cfg['ret_coro'] = True self.cfg['ret_coro'] = True
self.ap.logger.info('运行 QQ 官方适配器') await self.logger.info('运行 QQ 官方适配器')
await (await self.bot.start(**self.cfg)) await (await self.bot.start(**self.cfg))
async def kill(self) -> bool: async def kill(self) -> bool:

Some files were not shown because too many files have changed in this diff Show More