mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-26 03:44:58 +08:00
Compare commits
53 Commits
v4.0.8
...
v4.3.0.bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d3a147bbdd | ||
|
|
8eb1b8759b | ||
|
|
0155d3b0b9 | ||
|
|
e47a5b4e0d | ||
|
|
4012310d99 | ||
|
|
9e9bc88473 | ||
|
|
53ade384eb | ||
|
|
8b2480ad3b | ||
|
|
b176959836 | ||
|
|
a0c42a5f6e | ||
|
|
17d997c88e | ||
|
|
0ea7609ff1 | ||
|
|
28d4b1dd61 | ||
|
|
5179b3e53a | ||
|
|
288b294148 | ||
|
|
b464d238c5 | ||
|
|
e1a78e8ff9 | ||
|
|
2b8eb5f01c | ||
|
|
bf2bc70794 | ||
|
|
ebe0b68e8f | ||
|
|
39c50d3c12 | ||
|
|
621f1301b3 | ||
|
|
0b60ef0d06 | ||
|
|
41650b585a | ||
|
|
f5b893cfe0 | ||
|
|
e0abd19636 | ||
|
|
4380041c7f | ||
|
|
65814a4644 | ||
|
|
7237294008 | ||
|
|
214bc8ada9 | ||
|
|
6a1de889b4 | ||
|
|
4a319b2b20 | ||
|
|
9f269d1614 | ||
|
|
4b57771eb1 | ||
|
|
5922be7e15 | ||
|
|
10a44c70b6 | ||
|
|
5b044a1917 | ||
|
|
a60aa6f644 | ||
|
|
1a10b40b17 | ||
|
|
e2124054bf | ||
|
|
ee3da8aa17 | ||
|
|
c246470b37 | ||
|
|
f474e42b79 | ||
|
|
5553a86ac8 | ||
|
|
01613b2f0d | ||
|
|
a177786063 | ||
|
|
62b2884011 | ||
|
|
6b782f8761 | ||
|
|
0c2560cafb | ||
|
|
c5eeab2fd0 | ||
|
|
6f2fd72af6 | ||
|
|
2d06f1cadb | ||
|
|
af493c117c |
@@ -119,7 +119,6 @@ docker compose up -d
|
|||||||
| [Anthropic](https://www.anthropic.com/) | ✅ | |
|
| [Anthropic](https://www.anthropic.com/) | ✅ | |
|
||||||
| [xAI](https://x.ai/) | ✅ | |
|
| [xAI](https://x.ai/) | ✅ | |
|
||||||
| [智谱AI](https://open.bigmodel.cn/) | ✅ | |
|
| [智谱AI](https://open.bigmodel.cn/) | ✅ | |
|
||||||
| [优云智算](https://www.compshare.cn/) | ✅ | 大模型和 GPU 资源平台 |
|
|
||||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型和 GPU 资源平台 |
|
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型和 GPU 资源平台 |
|
||||||
| [302 AI](https://share.302.ai/SuTG99) | ✅ | 大模型聚合平台 |
|
| [302 AI](https://share.302.ai/SuTG99) | ✅ | 大模型聚合平台 |
|
||||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
||||||
|
|||||||
@@ -116,7 +116,6 @@ Directly use the released version to run, see the [Manual Deployment](https://do
|
|||||||
| [Anthropic](https://www.anthropic.com/) | ✅ | |
|
| [Anthropic](https://www.anthropic.com/) | ✅ | |
|
||||||
| [xAI](https://x.ai/) | ✅ | |
|
| [xAI](https://x.ai/) | ✅ | |
|
||||||
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
|
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
|
||||||
| [CompShare](https://www.compshare.cn/) | ✅ | LLM and GPU resource platform |
|
|
||||||
| [Dify](https://dify.ai) | ✅ | LLMOps platform |
|
| [Dify](https://dify.ai) | ✅ | LLMOps platform |
|
||||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | LLM and GPU resource platform |
|
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | LLM and GPU resource platform |
|
||||||
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLM gateway(MaaS) |
|
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLM gateway(MaaS) |
|
||||||
|
|||||||
@@ -115,7 +115,6 @@ LangBotはBTPanelにリストされています。BTPanelをインストール
|
|||||||
| [Anthropic](https://www.anthropic.com/) | ✅ | |
|
| [Anthropic](https://www.anthropic.com/) | ✅ | |
|
||||||
| [xAI](https://x.ai/) | ✅ | |
|
| [xAI](https://x.ai/) | ✅ | |
|
||||||
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
|
| [Zhipu AI](https://open.bigmodel.cn/) | ✅ | |
|
||||||
| [CompShare](https://www.compshare.cn/) | ✅ | 大模型とGPUリソースプラットフォーム |
|
|
||||||
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型とGPUリソースプラットフォーム |
|
| [PPIO](https://ppinfra.com/user/register?invited_by=QJKFYD&utm_source=github_langbot) | ✅ | 大模型とGPUリソースプラットフォーム |
|
||||||
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLMゲートウェイ(MaaS) |
|
| [302 AI](https://share.302.ai/SuTG99) | ✅ | LLMゲートウェイ(MaaS) |
|
||||||
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
| [Google Gemini](https://aistudio.google.com/prompts/new_chat) | ✅ | |
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ spec:
|
|||||||
components:
|
components:
|
||||||
ComponentTemplate:
|
ComponentTemplate:
|
||||||
fromFiles:
|
fromFiles:
|
||||||
- pkg/platform/adapter.yaml
|
|
||||||
- pkg/provider/modelmgr/requester.yaml
|
- pkg/provider/modelmgr/requester.yaml
|
||||||
MessagePlatformAdapter:
|
MessagePlatformAdapter:
|
||||||
fromDirs:
|
fromDirs:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from quart import request
|
|||||||
import httpx
|
import httpx
|
||||||
from quart import Quart
|
from quart import Quart
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable, Dict, Any
|
||||||
from pkg.platform.types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from .qqofficialevent import QQOfficialEvent
|
from .qqofficialevent import QQOfficialEvent
|
||||||
import json
|
import json
|
||||||
import traceback
|
import traceback
|
||||||
@@ -104,7 +104,7 @@ class QQOfficialClient:
|
|||||||
return {'code': 0, 'message': 'success'}
|
return {'code': 0, 'message': 'success'}
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
|
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||||
return {'error': str(e)}, 400
|
return {'error': str(e)}, 400
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
@@ -168,7 +168,6 @@ class QQOfficialClient:
|
|||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
|
|
||||||
url = self.base_url + '/v2/users/' + user_openid + '/messages'
|
url = self.base_url + '/v2/users/' + user_openid + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
@@ -193,7 +192,6 @@ class QQOfficialClient:
|
|||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
|
|
||||||
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
|
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
@@ -209,7 +207,7 @@ class QQOfficialClient:
|
|||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
await self.logger.error(f"发送群聊消息失败:{response.json()}")
|
await self.logger.error(f'发送群聊消息失败:{response.json()}')
|
||||||
raise Exception(response.read().decode())
|
raise Exception(response.read().decode())
|
||||||
|
|
||||||
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
|
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
|
||||||
@@ -217,7 +215,6 @@ class QQOfficialClient:
|
|||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
|
|
||||||
url = self.base_url + '/channels/' + channel_id + '/messages'
|
url = self.base_url + '/channels/' + channel_id + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
headers = {
|
headers = {
|
||||||
@@ -240,7 +237,6 @@ class QQOfficialClient:
|
|||||||
"""发送频道私聊消息"""
|
"""发送频道私聊消息"""
|
||||||
if not await self.check_access_token():
|
if not await self.check_access_token():
|
||||||
await self.get_access_token()
|
await self.get_access_token()
|
||||||
|
|
||||||
|
|
||||||
url = self.base_url + '/dms/' + guild_id + '/messages'
|
url = self.base_url + '/dms/' + guild_id + '/messages'
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
|
|||||||
@@ -4,7 +4,7 @@ from quart import Quart, jsonify, request
|
|||||||
from slack_sdk.web.async_client import AsyncWebClient
|
from slack_sdk.web.async_client import AsyncWebClient
|
||||||
from .slackevent import SlackEvent
|
from .slackevent import SlackEvent
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from pkg.platform.types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
|
|
||||||
|
|
||||||
class SlackClient:
|
class SlackClient:
|
||||||
@@ -34,7 +34,6 @@ class SlackClient:
|
|||||||
|
|
||||||
if self.bot_user_id and bot_user_id == self.bot_user_id:
|
if self.bot_user_id and bot_user_id == self.bot_user_id:
|
||||||
return jsonify({'status': 'ok'})
|
return jsonify({'status': 'ok'})
|
||||||
|
|
||||||
|
|
||||||
# 处理私信
|
# 处理私信
|
||||||
if data and data.get('event', {}).get('channel_type') in ['im']:
|
if data and data.get('event', {}).get('channel_type') in ['im']:
|
||||||
@@ -52,7 +51,7 @@ class SlackClient:
|
|||||||
return jsonify({'status': 'ok'})
|
return jsonify({'status': 'ok'})
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
|
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||||
raise (e)
|
raise (e)
|
||||||
|
|
||||||
async def _handle_message(self, event: SlackEvent):
|
async def _handle_message(self, event: SlackEvent):
|
||||||
@@ -82,7 +81,7 @@ class SlackClient:
|
|||||||
self.bot_user_id = response['message']['bot_id']
|
self.bot_user_id = response['message']['bot_id']
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.logger.error(f"Error in send_message: {e}")
|
await self.logger.error(f'Error in send_message: {e}')
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def send_message_to_one(self, text: str, user_id: str):
|
async def send_message_to_one(self, text: str, user_id: str):
|
||||||
@@ -93,7 +92,7 @@ class SlackClient:
|
|||||||
|
|
||||||
return
|
return
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.logger.error(f"Error in send_message: {traceback.format_exc()}")
|
await self.logger.error(f'Error in send_message: {traceback.format_exc()}')
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
|
|||||||
@@ -1 +1,4 @@
|
|||||||
from .client import WeChatPadClient
|
from .client import WeChatPadClient
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ['WeChatPadClient']
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
from libs.wechatpad_api.util.http_util import async_request, post_json
|
from libs.wechatpad_api.util.http_util import post_json
|
||||||
|
|
||||||
|
|
||||||
class ChatRoomApi:
|
class ChatRoomApi:
|
||||||
@@ -7,8 +7,6 @@ class ChatRoomApi:
|
|||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
def get_chatroom_member_detail(self, chatroom_name):
|
def get_chatroom_member_detail(self, chatroom_name):
|
||||||
params = {
|
params = {'ChatRoomName': chatroom_name}
|
||||||
"ChatRoomName": chatroom_name
|
|
||||||
}
|
|
||||||
url = self.base_url + '/group/GetChatroomMemberDetail'
|
url = self.base_url + '/group/GetChatroomMemberDetail'
|
||||||
return post_json(url, token=self.token, data=params)
|
return post_json(url, token=self.token, data=params)
|
||||||
|
|||||||
@@ -1,32 +1,23 @@
|
|||||||
from libs.wechatpad_api.util.http_util import async_request, post_json
|
from libs.wechatpad_api.util.http_util import post_json
|
||||||
import httpx
|
import httpx
|
||||||
import base64
|
import base64
|
||||||
|
|
||||||
|
|
||||||
class DownloadApi:
|
class DownloadApi:
|
||||||
def __init__(self, base_url, token):
|
def __init__(self, base_url, token):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
def send_download(self, aeskey, file_type, file_url):
|
def send_download(self, aeskey, file_type, file_url):
|
||||||
json_data = {
|
json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
|
||||||
"AesKey": aeskey,
|
url = self.base_url + '/message/SendCdnDownload'
|
||||||
"FileType": file_type,
|
|
||||||
"FileURL": file_url
|
|
||||||
}
|
|
||||||
url = self.base_url + "/message/SendCdnDownload"
|
|
||||||
return post_json(url, token=self.token, data=json_data)
|
return post_json(url, token=self.token, data=json_data)
|
||||||
|
|
||||||
def get_msg_voice(self,buf_id, length, new_msgid):
|
def get_msg_voice(self, buf_id, length, new_msgid):
|
||||||
json_data = {
|
json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
|
||||||
"Bufid": buf_id,
|
url = self.base_url + '/message/GetMsgVoice'
|
||||||
"Length": length,
|
|
||||||
"NewMsgId": new_msgid,
|
|
||||||
"ToUserName": ""
|
|
||||||
}
|
|
||||||
url = self.base_url + "/message/GetMsgVoice"
|
|
||||||
return post_json(url, token=self.token, data=json_data)
|
return post_json(url, token=self.token, data=json_data)
|
||||||
|
|
||||||
|
|
||||||
async def download_url_to_base64(self, download_url):
|
async def download_url_to_base64(self, download_url):
|
||||||
async with httpx.AsyncClient() as client:
|
async with httpx.AsyncClient() as client:
|
||||||
response = await client.get(download_url)
|
response = await client.get(download_url)
|
||||||
@@ -36,4 +27,4 @@ class DownloadApi:
|
|||||||
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
||||||
return base64_str
|
return base64_str
|
||||||
else:
|
else:
|
||||||
raise Exception('获取文件失败')
|
raise Exception('获取文件失败')
|
||||||
|
|||||||
@@ -1,11 +1,6 @@
|
|||||||
from libs.wechatpad_api.util.http_util import post_json,async_request
|
|
||||||
from typing import List, Dict, Any, Optional
|
|
||||||
|
|
||||||
|
|
||||||
class FriendApi:
|
class FriendApi:
|
||||||
"""联系人API类,处理所有与联系人相关的操作"""
|
"""联系人API类,处理所有与联系人相关的操作"""
|
||||||
|
|
||||||
def __init__(self, base_url: str, token: str):
|
def __init__(self, base_url: str, token: str):
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
|
|||||||
@@ -1,37 +1,34 @@
|
|||||||
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json
|
from libs.wechatpad_api.util.http_util import post_json, get_json
|
||||||
|
|
||||||
|
|
||||||
class LoginApi:
|
class LoginApi:
|
||||||
def __init__(self, base_url: str, token: str = None, admin_key: str = None):
|
def __init__(self, base_url: str, token: str = None, admin_key: str = None):
|
||||||
'''
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
base_url: 原始路径
|
base_url: 原始路径
|
||||||
token: token
|
token: token
|
||||||
admin_key: 管理员key
|
admin_key: 管理员key
|
||||||
'''
|
"""
|
||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.token = token
|
self.token = token
|
||||||
# self.admin_key = admin_key
|
# self.admin_key = admin_key
|
||||||
|
|
||||||
def get_token(self, admin_key, day: int=365):
|
def get_token(self, admin_key, day: int = 365):
|
||||||
# 获取普通token
|
# 获取普通token
|
||||||
url = f"{self.base_url}/admin/GenAuthKey1"
|
url = f'{self.base_url}/admin/GenAuthKey1'
|
||||||
json_data = {
|
json_data = {'Count': 1, 'Days': day}
|
||||||
"Count": 1,
|
|
||||||
"Days": day
|
|
||||||
}
|
|
||||||
return post_json(base_url=url, token=admin_key, data=json_data)
|
return post_json(base_url=url, token=admin_key, data=json_data)
|
||||||
|
|
||||||
def get_login_qr(self, Proxy: str = ""):
|
def get_login_qr(self, Proxy: str = ''):
|
||||||
'''
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
Proxy:异地使用时代理
|
Proxy:异地使用时代理
|
||||||
|
|
||||||
Returns:json数据
|
Returns:json数据
|
||||||
|
|
||||||
'''
|
"""
|
||||||
"""
|
"""
|
||||||
|
|
||||||
{
|
{
|
||||||
@@ -49,54 +46,37 @@ class LoginApi:
|
|||||||
}
|
}
|
||||||
|
|
||||||
"""
|
"""
|
||||||
#获取登录二维码
|
# 获取登录二维码
|
||||||
url = f"{self.base_url}/login/GetLoginQrCodeNew"
|
url = f'{self.base_url}/login/GetLoginQrCodeNew'
|
||||||
check = False
|
check = False
|
||||||
if Proxy != "":
|
if Proxy != '':
|
||||||
check = True
|
check = True
|
||||||
json_data = {
|
json_data = {'Check': check, 'Proxy': Proxy}
|
||||||
"Check": check,
|
|
||||||
"Proxy": Proxy
|
|
||||||
}
|
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
|
|
||||||
|
|
||||||
def get_login_status(self):
|
def get_login_status(self):
|
||||||
# 获取登录状态
|
# 获取登录状态
|
||||||
url = f'{self.base_url}/login/GetLoginStatus'
|
url = f'{self.base_url}/login/GetLoginStatus'
|
||||||
return get_json(base_url=url, token=self.token)
|
return get_json(base_url=url, token=self.token)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def logout(self):
|
def logout(self):
|
||||||
# 退出登录
|
# 退出登录
|
||||||
url = f'{self.base_url}/login/LogOut'
|
url = f'{self.base_url}/login/LogOut'
|
||||||
return post_json(base_url=url, token=self.token)
|
return post_json(base_url=url, token=self.token)
|
||||||
|
|
||||||
|
def wake_up_login(self, Proxy: str = ''):
|
||||||
|
|
||||||
|
|
||||||
def wake_up_login(self, Proxy: str = ""):
|
|
||||||
# 唤醒登录
|
# 唤醒登录
|
||||||
url = f'{self.base_url}/login/WakeUpLogin'
|
url = f'{self.base_url}/login/WakeUpLogin'
|
||||||
check = False
|
check = False
|
||||||
if Proxy != "":
|
if Proxy != '':
|
||||||
check = True
|
check = True
|
||||||
json_data = {
|
json_data = {'Check': check, 'Proxy': ''}
|
||||||
"Check": check,
|
|
||||||
"Proxy": ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
|
|
||||||
|
def login(self, admin_key):
|
||||||
|
|
||||||
def login(self,admin_key):
|
|
||||||
login_status = self.get_login_status()
|
login_status = self.get_login_status()
|
||||||
if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信":
|
if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
|
||||||
print("token已经失效,重新获取")
|
print('token已经失效,重新获取')
|
||||||
token_data = self.get_token(admin_key)
|
token_data = self.get_token(admin_key)
|
||||||
self.token = token_data["Data"][0]
|
self.token = token_data['Data'][0]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
|
from libs.wechatpad_api.util.http_util import post_json
|
||||||
from libs.wechatpad_api.util.http_util import async_request, post_json
|
|
||||||
|
|
||||||
|
|
||||||
class MessageApi:
|
class MessageApi:
|
||||||
@@ -7,8 +6,8 @@ class MessageApi:
|
|||||||
self.base_url = base_url
|
self.base_url = base_url
|
||||||
self.token = token
|
self.token = token
|
||||||
|
|
||||||
def post_text(self, to_wxid, content, ats: list= []):
|
def post_text(self, to_wxid, content, ats: list = []):
|
||||||
'''
|
"""
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
app_id: 微信id
|
app_id: 微信id
|
||||||
@@ -18,106 +17,64 @@ class MessageApi:
|
|||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
|
||||||
'''
|
"""
|
||||||
url = self.base_url + "/message/SendTextMessage"
|
url = self.base_url + '/message/SendTextMessage'
|
||||||
"""发送文字消息"""
|
"""发送文字消息"""
|
||||||
json_data = {
|
json_data = {
|
||||||
"MsgItem": [
|
'MsgItem': [
|
||||||
{
|
{'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
|
||||||
"AtWxIDList": ats,
|
]
|
||||||
"ImageContent": "",
|
}
|
||||||
"MsgType": 0,
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
"TextContent": content,
|
|
||||||
"ToUserName": to_wxid
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
|
||||||
|
|
||||||
|
def post_image(self, to_wxid, img_url, ats: list = []):
|
||||||
|
|
||||||
|
|
||||||
def post_image(self, to_wxid, img_url, ats: list= []):
|
|
||||||
"""发送图片消息"""
|
"""发送图片消息"""
|
||||||
# 这里好像可以尝试发送多个暂时未测试
|
# 这里好像可以尝试发送多个暂时未测试
|
||||||
json_data = {
|
json_data = {
|
||||||
"MsgItem": [
|
'MsgItem': [
|
||||||
{
|
{'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
|
||||||
"AtWxIDList": ats,
|
|
||||||
"ImageContent": img_url,
|
|
||||||
"MsgType": 0,
|
|
||||||
"TextContent": '',
|
|
||||||
"ToUserName": to_wxid
|
|
||||||
}
|
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
url = self.base_url + "/message/SendImageMessage"
|
url = self.base_url + '/message/SendImageMessage'
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
|
|
||||||
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
|
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
|
||||||
"""发送语音消息"""
|
"""发送语音消息"""
|
||||||
json_data = {
|
json_data = {
|
||||||
"ToUserName": to_wxid,
|
'ToUserName': to_wxid,
|
||||||
"VoiceData": voice_data,
|
'VoiceData': voice_data,
|
||||||
"VoiceFormat": voice_forma,
|
'VoiceFormat': voice_forma,
|
||||||
"VoiceSecond": voice_duration
|
'VoiceSecond': voice_duration,
|
||||||
}
|
}
|
||||||
url = self.base_url + "/message/SendVoice"
|
url = self.base_url + '/message/SendVoice'
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
|
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
|
||||||
"""发送名片消息"""
|
"""发送名片消息"""
|
||||||
param = {
|
param = {
|
||||||
"CardAlias": alias,
|
'CardAlias': alias,
|
||||||
"CardFlag": flag,
|
'CardFlag': flag,
|
||||||
"CardNickName": nick_name,
|
'CardNickName': nick_name,
|
||||||
"CardWxId": name_card_wxid,
|
'CardWxId': name_card_wxid,
|
||||||
"ToUserName": to_wxid
|
'ToUserName': to_wxid,
|
||||||
}
|
}
|
||||||
url = f"{self.base_url}/message/ShareCardMessage"
|
url = f'{self.base_url}/message/ShareCardMessage'
|
||||||
return post_json(base_url=url, token=self.token, data=param)
|
return post_json(base_url=url, token=self.token, data=param)
|
||||||
|
|
||||||
def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0):
|
def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
|
||||||
"""发送emoji消息"""
|
"""发送emoji消息"""
|
||||||
json_data = {
|
json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
|
||||||
"EmojiList": [
|
url = f'{self.base_url}/message/SendEmojiMessage'
|
||||||
{
|
|
||||||
"EmojiMd5": emoji_md5,
|
|
||||||
"EmojiSize": emoji_size,
|
|
||||||
"ToUserName": to_wxid
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
url = f"{self.base_url}/message/SendEmojiMessage"
|
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
|
|
||||||
def post_app_msg(self, to_wxid,xml_data, contenttype:int=0):
|
def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
|
||||||
"""发送appmsg消息"""
|
"""发送appmsg消息"""
|
||||||
json_data = {
|
json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
|
||||||
"AppList": [
|
url = f'{self.base_url}/message/SendAppMessage'
|
||||||
{
|
|
||||||
"ContentType": contenttype,
|
|
||||||
"ContentXML": xml_data,
|
|
||||||
"ToUserName": to_wxid
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
url = f"{self.base_url}/message/SendAppMessage"
|
|
||||||
return post_json(base_url=url, token=self.token, data=json_data)
|
return post_json(base_url=url, token=self.token, data=json_data)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
|
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
|
||||||
"""撤回消息"""
|
"""撤回消息"""
|
||||||
param = {
|
param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
|
||||||
"ClientMsgId": msg_id,
|
url = f'{self.base_url}/message/RevokeMsg'
|
||||||
"CreateTime": create_time,
|
return post_json(base_url=url, token=self.token, data=param)
|
||||||
"NewMsgId": new_msg_id,
|
|
||||||
"ToUserName": to_wxid
|
|
||||||
}
|
|
||||||
url = f"{self.base_url}/message/RevokeMsg"
|
|
||||||
return post_json(base_url=url, token=self.token, data=param)
|
|
||||||
|
|||||||
@@ -12,12 +12,9 @@ class UserApi:
|
|||||||
|
|
||||||
return get_json(base_url=url, token=self.token)
|
return get_json(base_url=url, token=self.token)
|
||||||
|
|
||||||
def get_qr_code(self, recover:bool=True, style:int=8):
|
def get_qr_code(self, recover: bool = True, style: int = 8):
|
||||||
"""获取自己的二维码"""
|
"""获取自己的二维码"""
|
||||||
param = {
|
param = {'Recover': recover, 'Style': style}
|
||||||
"Recover": recover,
|
|
||||||
"Style": style
|
|
||||||
}
|
|
||||||
url = f'{self.base_url}/user/GetMyQRCode'
|
url = f'{self.base_url}/user/GetMyQRCode'
|
||||||
return post_json(base_url=url, token=self.token, data=param)
|
return post_json(base_url=url, token=self.token, data=param)
|
||||||
|
|
||||||
@@ -26,12 +23,8 @@ class UserApi:
|
|||||||
url = f'{self.base_url}/equipment/GetSafetyInfo'
|
url = f'{self.base_url}/equipment/GetSafetyInfo'
|
||||||
return post_json(base_url=url, token=self.token)
|
return post_json(base_url=url, token=self.token)
|
||||||
|
|
||||||
|
async def update_head_img(self, head_img_base64):
|
||||||
|
|
||||||
async def update_head_img(self, head_img_base64):
|
|
||||||
"""修改头像"""
|
"""修改头像"""
|
||||||
param = {
|
param = {'Base64': head_img_base64}
|
||||||
"Base64": head_img_base64
|
|
||||||
}
|
|
||||||
url = f'{self.base_url}/user/UploadHeadImage'
|
url = f'{self.base_url}/user/UploadHeadImage'
|
||||||
return await async_request(base_url=url, token_key=self.token, json=param)
|
return await async_request(base_url=url, token_key=self.token, json=param)
|
||||||
|
|||||||
@@ -1,4 +1,3 @@
|
|||||||
|
|
||||||
from libs.wechatpad_api.api.login import LoginApi
|
from libs.wechatpad_api.api.login import LoginApi
|
||||||
from libs.wechatpad_api.api.friend import FriendApi
|
from libs.wechatpad_api.api.friend import FriendApi
|
||||||
from libs.wechatpad_api.api.message import MessageApi
|
from libs.wechatpad_api.api.message import MessageApi
|
||||||
@@ -7,9 +6,6 @@ from libs.wechatpad_api.api.downloadpai import DownloadApi
|
|||||||
from libs.wechatpad_api.api.chatroom import ChatRoomApi
|
from libs.wechatpad_api.api.chatroom import ChatRoomApi
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class WeChatPadClient:
|
class WeChatPadClient:
|
||||||
def __init__(self, base_url, token, logger=None):
|
def __init__(self, base_url, token, logger=None):
|
||||||
self._login_api = LoginApi(base_url, token)
|
self._login_api = LoginApi(base_url, token)
|
||||||
@@ -20,16 +16,16 @@ class WeChatPadClient:
|
|||||||
self._chatroom_api = ChatRoomApi(base_url, token)
|
self._chatroom_api = ChatRoomApi(base_url, token)
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
|
|
||||||
def get_token(self,admin_key, day: int):
|
def get_token(self, admin_key, day: int):
|
||||||
'''获取token'''
|
"""获取token"""
|
||||||
return self._login_api.get_token(admin_key, day)
|
return self._login_api.get_token(admin_key, day)
|
||||||
|
|
||||||
def get_login_qr(self, Proxy:str=""):
|
def get_login_qr(self, Proxy: str = ''):
|
||||||
"""登录二维码"""
|
"""登录二维码"""
|
||||||
return self._login_api.get_login_qr(Proxy=Proxy)
|
return self._login_api.get_login_qr(Proxy=Proxy)
|
||||||
|
|
||||||
def awaken_login(self, Proxy:str=""):
|
def awaken_login(self, Proxy: str = ''):
|
||||||
'''唤醒登录'''
|
"""唤醒登录"""
|
||||||
return self._login_api.wake_up_login(Proxy=Proxy)
|
return self._login_api.wake_up_login(Proxy=Proxy)
|
||||||
|
|
||||||
def log_out(self):
|
def log_out(self):
|
||||||
@@ -40,59 +36,57 @@ class WeChatPadClient:
|
|||||||
"""获取登录状态"""
|
"""获取登录状态"""
|
||||||
return self._login_api.get_login_status()
|
return self._login_api.get_login_status()
|
||||||
|
|
||||||
def send_text_message(self, to_wxid, message, ats: list=[]):
|
def send_text_message(self, to_wxid, message, ats: list = []):
|
||||||
"""发送文本消息"""
|
"""发送文本消息"""
|
||||||
return self._message_api.post_text(to_wxid, message, ats)
|
return self._message_api.post_text(to_wxid, message, ats)
|
||||||
|
|
||||||
def send_image_message(self, to_wxid, img_url, ats: list=[]):
|
def send_image_message(self, to_wxid, img_url, ats: list = []):
|
||||||
"""发送图片消息"""
|
"""发送图片消息"""
|
||||||
return self._message_api.post_image(to_wxid, img_url, ats)
|
return self._message_api.post_image(to_wxid, img_url, ats)
|
||||||
|
|
||||||
def send_voice_message(self, to_wxid, voice_data, voice_forma, voice_duration):
|
def send_voice_message(self, to_wxid, voice_data, voice_forma, voice_duration):
|
||||||
"""发送音频消息"""
|
"""发送音频消息"""
|
||||||
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
|
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
|
||||||
|
|
||||||
def send_app_message(self, to_wxid, app_message, type):
|
def send_app_message(self, to_wxid, app_message, type):
|
||||||
"""发送app消息"""
|
"""发送app消息"""
|
||||||
return self._message_api.post_app_msg(to_wxid, app_message, type)
|
return self._message_api.post_app_msg(to_wxid, app_message, type)
|
||||||
|
|
||||||
def send_emoji_message(self, to_wxid, emoji_md5, emoji_size):
|
def send_emoji_message(self, to_wxid, emoji_md5, emoji_size):
|
||||||
"""发送emoji消息"""
|
"""发送emoji消息"""
|
||||||
return self._message_api.post_emoji(to_wxid,emoji_md5,emoji_size)
|
return self._message_api.post_emoji(to_wxid, emoji_md5, emoji_size)
|
||||||
|
|
||||||
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
|
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
|
||||||
"""撤回消息"""
|
"""撤回消息"""
|
||||||
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
|
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
|
||||||
|
|
||||||
def get_profile(self):
|
def get_profile(self):
|
||||||
"""获取用户信息"""
|
"""获取用户信息"""
|
||||||
return self._user_api.get_profile()
|
return self._user_api.get_profile()
|
||||||
|
|
||||||
def get_qr_code(self, recover:bool=True, style:int=8):
|
def get_qr_code(self, recover: bool = True, style: int = 8):
|
||||||
"""获取用户二维码"""
|
"""获取用户二维码"""
|
||||||
return self._user_api.get_qr_code(recover=recover, style=style)
|
return self._user_api.get_qr_code(recover=recover, style=style)
|
||||||
|
|
||||||
def get_safety_info(self):
|
def get_safety_info(self):
|
||||||
"""获取设备信息"""
|
"""获取设备信息"""
|
||||||
return self._user_api.get_safety_info()
|
return self._user_api.get_safety_info()
|
||||||
|
|
||||||
def update_head_img(self, head_img_base64):
|
def update_head_img(self, head_img_base64):
|
||||||
"""上传用户头像"""
|
"""上传用户头像"""
|
||||||
return self._user_api.update_head_img(head_img_base64)
|
return self._user_api.update_head_img(head_img_base64)
|
||||||
|
|
||||||
def cdn_download(self, aeskey, file_type, file_url):
|
def cdn_download(self, aeskey, file_type, file_url):
|
||||||
"""cdn下载"""
|
"""cdn下载"""
|
||||||
return self._download_api.send_download( aeskey, file_type, file_url)
|
return self._download_api.send_download(aeskey, file_type, file_url)
|
||||||
|
|
||||||
def get_msg_voice(self,buf_id, length, msgid):
|
def get_msg_voice(self, buf_id, length, msgid):
|
||||||
"""下载语音"""
|
"""下载语音"""
|
||||||
return self._download_api.get_msg_voice(buf_id, length, msgid)
|
return self._download_api.get_msg_voice(buf_id, length, msgid)
|
||||||
|
|
||||||
async def download_base64(self,url):
|
async def download_base64(self, url):
|
||||||
return await self._download_api.download_url_to_base64(download_url=url)
|
return await self._download_api.download_url_to_base64(download_url=url)
|
||||||
|
|
||||||
def get_chatroom_member_detail(self, chatroom_name):
|
def get_chatroom_member_detail(self, chatroom_name):
|
||||||
"""查看群成员详情"""
|
"""查看群成员详情"""
|
||||||
return self._chatroom_api.get_chatroom_member_detail(chatroom_name)
|
return self._chatroom_api.get_chatroom_member_detail(chatroom_name)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
import requests
|
import requests
|
||||||
|
import aiohttp
|
||||||
|
|
||||||
|
|
||||||
def post_json(base_url, token, data=None):
|
def post_json(base_url, token, data=None):
|
||||||
headers = {
|
headers = {'Content-Type': 'application/json'}
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
url = base_url + f'?key={token}'
|
url = base_url + f'?key={token}'
|
||||||
|
|
||||||
@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(response.text)
|
raise RuntimeError(response.text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"http请求失败, url={url}, exception={e}")
|
print(f'http请求失败, url={url}, exception={e}')
|
||||||
raise RuntimeError(str(e))
|
raise RuntimeError(str(e))
|
||||||
|
|
||||||
def get_json(base_url, token):
|
|
||||||
headers = {
|
|
||||||
'Content-Type': 'application/json'
|
|
||||||
}
|
|
||||||
|
|
||||||
|
def get_json(base_url, token):
|
||||||
|
headers = {'Content-Type': 'application/json'}
|
||||||
|
|
||||||
url = base_url + f'?key={token}'
|
url = base_url + f'?key={token}'
|
||||||
|
|
||||||
@@ -39,21 +36,18 @@ def get_json(base_url, token):
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError(response.text)
|
raise RuntimeError(response.text)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"http请求失败, url={url}, exception={e}")
|
print(f'http请求失败, url={url}, exception={e}')
|
||||||
raise RuntimeError(str(e))
|
raise RuntimeError(str(e))
|
||||||
|
|
||||||
import aiohttp
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
|
|
||||||
async def async_request(
|
async def async_request(
|
||||||
base_url: str,
|
base_url: str,
|
||||||
token_key: str,
|
token_key: str,
|
||||||
method: str = 'POST',
|
method: str = 'POST',
|
||||||
params: dict = None,
|
params: dict = None,
|
||||||
# headers: dict = None,
|
# headers: dict = None,
|
||||||
data: dict = None,
|
data: dict = None,
|
||||||
json: dict = None
|
json: dict = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
通用异步请求函数
|
通用异步请求函数
|
||||||
@@ -67,18 +61,11 @@ async def async_request(
|
|||||||
:param json: JSON数据
|
:param json: JSON数据
|
||||||
:return: 响应文本
|
:return: 响应文本
|
||||||
"""
|
"""
|
||||||
headers = {
|
headers = {'Content-Type': 'application/json'}
|
||||||
'Content-Type': 'application/json'
|
url = f'{base_url}?key={token_key}'
|
||||||
}
|
|
||||||
url = f"{base_url}?key={token_key}"
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.request(
|
async with session.request(
|
||||||
method=method,
|
method=method, url=url, params=params, headers=headers, data=data, json=json
|
||||||
url=url,
|
|
||||||
params=params,
|
|
||||||
headers=headers,
|
|
||||||
data=data,
|
|
||||||
json=json
|
|
||||||
) as response:
|
) as response:
|
||||||
response.raise_for_status() # 如果状态码不是200,抛出异常
|
response.raise_for_status() # 如果状态码不是200,抛出异常
|
||||||
result = await response.json()
|
result = await response.json()
|
||||||
@@ -89,4 +76,3 @@ async def async_request(
|
|||||||
# return await result
|
# return await result
|
||||||
# else:
|
# else:
|
||||||
# raise RuntimeError("请求失败",response.text)
|
# raise RuntimeError("请求失败",response.text)
|
||||||
|
|
||||||
|
|||||||
@@ -1,31 +1,34 @@
|
|||||||
import qrcode
|
import qrcode
|
||||||
|
|
||||||
|
|
||||||
def print_green(text):
|
def print_green(text):
|
||||||
print(f"\033[32m{text}\033[0m")
|
print(f'\033[32m{text}\033[0m')
|
||||||
|
|
||||||
|
|
||||||
def print_yellow(text):
|
def print_yellow(text):
|
||||||
print(f"\033[33m{text}\033[0m")
|
print(f'\033[33m{text}\033[0m')
|
||||||
|
|
||||||
|
|
||||||
def print_red(text):
|
def print_red(text):
|
||||||
print(f"\033[31m{text}\033[0m")
|
print(f'\033[31m{text}\033[0m')
|
||||||
|
|
||||||
|
|
||||||
def make_and_print_qr(url):
|
def make_and_print_qr(url):
|
||||||
"""生成并打印二维码
|
"""生成并打印二维码
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
url: 需要生成二维码的URL字符串
|
url: 需要生成二维码的URL字符串
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
|
|
||||||
功能:
|
功能:
|
||||||
1. 在终端打印二维码的ASCII图形
|
1. 在终端打印二维码的ASCII图形
|
||||||
2. 同时提供在线二维码生成链接作为备选
|
2. 同时提供在线二维码生成链接作为备选
|
||||||
"""
|
"""
|
||||||
print_green("请扫描下方二维码登录")
|
print_green('请扫描下方二维码登录')
|
||||||
qr = qrcode.QRCode()
|
qr = qrcode.QRCode()
|
||||||
qr.add_data(url)
|
qr.add_data(url)
|
||||||
qr.make()
|
qr.make()
|
||||||
qr.print_ascii(invert=True)
|
qr.print_ascii(invert=True)
|
||||||
print_green(f"也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}")
|
print_green(f'也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}')
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from quart import Quart
|
|||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Callable, Dict, Any
|
from typing import Callable, Dict, Any
|
||||||
from .wecomevent import WecomEvent
|
from .wecomevent import WecomEvent
|
||||||
from pkg.platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
|
||||||
|
|
||||||
@@ -57,7 +57,7 @@ class WecomClient:
|
|||||||
if 'access_token' in data:
|
if 'access_token' in data:
|
||||||
return data['access_token']
|
return data['access_token']
|
||||||
else:
|
else:
|
||||||
await self.logger.error(f"获取accesstoken失败:{response.json()}")
|
await self.logger.error(f'获取accesstoken失败:{response.json()}')
|
||||||
raise Exception(f'未获取access token: {data}')
|
raise Exception(f'未获取access token: {data}')
|
||||||
|
|
||||||
async def get_users(self):
|
async def get_users(self):
|
||||||
@@ -129,7 +129,7 @@ class WecomClient:
|
|||||||
response = await client.post(url, json=params)
|
response = await client.post(url, json=params)
|
||||||
data = response.json()
|
data = response.json()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.logger.error(f"发送图片失败:{data}")
|
await self.logger.error(f'发送图片失败:{data}')
|
||||||
raise Exception('Failed to send image: ' + str(e))
|
raise Exception('Failed to send image: ' + str(e))
|
||||||
|
|
||||||
# 企业微信错误码40014和42001,代表accesstoken问题
|
# 企业微信错误码40014和42001,代表accesstoken问题
|
||||||
@@ -164,7 +164,7 @@ class WecomClient:
|
|||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
return await self.send_private_msg(user_id, agent_id, content)
|
return await self.send_private_msg(user_id, agent_id, content)
|
||||||
if data['errcode'] != 0:
|
if data['errcode'] != 0:
|
||||||
await self.logger.error(f"发送消息失败:{data}")
|
await self.logger.error(f'发送消息失败:{data}')
|
||||||
raise Exception('Failed to send message: ' + str(data))
|
raise Exception('Failed to send message: ' + str(data))
|
||||||
|
|
||||||
async def handle_callback_request(self):
|
async def handle_callback_request(self):
|
||||||
@@ -181,7 +181,7 @@ class WecomClient:
|
|||||||
echostr = request.args.get('echostr')
|
echostr = request.args.get('echostr')
|
||||||
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
await self.logger.error("验证失败")
|
await self.logger.error('验证失败')
|
||||||
raise Exception(f'验证失败,错误码: {ret}')
|
raise Exception(f'验证失败,错误码: {ret}')
|
||||||
return reply_echo_str
|
return reply_echo_str
|
||||||
|
|
||||||
@@ -189,9 +189,8 @@ class WecomClient:
|
|||||||
encrypt_msg = await request.data
|
encrypt_msg = await request.data
|
||||||
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||||
if ret != 0:
|
if ret != 0:
|
||||||
await self.logger.error("消息解密失败")
|
await self.logger.error('消息解密失败')
|
||||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||||
|
|
||||||
|
|
||||||
# 解析消息并处理
|
# 解析消息并处理
|
||||||
message_data = await self.get_message(xml_msg)
|
message_data = await self.get_message(xml_msg)
|
||||||
@@ -202,7 +201,7 @@ class WecomClient:
|
|||||||
|
|
||||||
return 'success'
|
return 'success'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
|
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||||
return f'Error processing request: {str(e)}', 400
|
return f'Error processing request: {str(e)}', 400
|
||||||
|
|
||||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||||
@@ -301,7 +300,7 @@ class WecomClient:
|
|||||||
except binascii.Error as e:
|
except binascii.Error as e:
|
||||||
raise ValueError(f'Invalid base64 string: {str(e)}')
|
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||||
else:
|
else:
|
||||||
await self.logger.error("Image对象出错")
|
await self.logger.error('Image对象出错')
|
||||||
raise ValueError('image对象出错')
|
raise ValueError('image对象出错')
|
||||||
|
|
||||||
# 设置 multipart/form-data 格式的文件
|
# 设置 multipart/form-data 格式的文件
|
||||||
@@ -325,7 +324,7 @@ class WecomClient:
|
|||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
media_id = await self.upload_to_work(image)
|
media_id = await self.upload_to_work(image)
|
||||||
if data.get('errcode', 0) != 0:
|
if data.get('errcode', 0) != 0:
|
||||||
await self.logger.error(f"上传图片失败:{data}")
|
await self.logger.error(f'上传图片失败:{data}')
|
||||||
raise Exception('failed to upload file')
|
raise Exception('failed to upload file')
|
||||||
|
|
||||||
media_id = data.get('media_id')
|
media_id = data.get('media_id')
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ from quart import Quart
|
|||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Callable
|
from typing import Callable
|
||||||
from .wecomcsevent import WecomCSEvent
|
from .wecomcsevent import WecomCSEvent
|
||||||
from pkg.platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
import aiofiles
|
import aiofiles
|
||||||
|
|
||||||
|
|
||||||
@@ -187,7 +187,7 @@ class WecomCSClient:
|
|||||||
self.access_token = await self.get_access_token(self.secret)
|
self.access_token = await self.get_access_token(self.secret)
|
||||||
return await self.send_text_msg(open_kfid, external_userid, msgid, content)
|
return await self.send_text_msg(open_kfid, external_userid, msgid, content)
|
||||||
if data['errcode'] != 0:
|
if data['errcode'] != 0:
|
||||||
await self.logger.error(f"发送消息失败:{data}")
|
await self.logger.error(f'发送消息失败:{data}')
|
||||||
raise Exception('Failed to send message')
|
raise Exception('Failed to send message')
|
||||||
return data
|
return data
|
||||||
|
|
||||||
@@ -227,7 +227,7 @@ class WecomCSClient:
|
|||||||
return 'success'
|
return 'success'
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
if self.logger:
|
if self.logger:
|
||||||
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
|
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||||
else:
|
else:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
return f'Error processing request: {str(e)}', 400
|
return f'Error processing request: {str(e)}', 400
|
||||||
|
|||||||
10
main.py
10
main.py
@@ -47,13 +47,13 @@ async def main_entry(loop: asyncio.AbstractEventLoop):
|
|||||||
if not args.skip_plugin_deps_check:
|
if not args.skip_plugin_deps_check:
|
||||||
await deps.precheck_plugin_deps()
|
await deps.precheck_plugin_deps()
|
||||||
|
|
||||||
# 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
# # 检查pydantic版本,如果没有 pydantic.v1,则把 pydantic 映射为 v1
|
||||||
import pydantic.version
|
# import pydantic.version
|
||||||
|
|
||||||
if pydantic.version.VERSION < '2.0':
|
# if pydantic.version.VERSION < '2.0':
|
||||||
import pydantic
|
# import pydantic
|
||||||
|
|
||||||
sys.modules['pydantic.v1'] = pydantic
|
# sys.modules['pydantic.v1'] = pydantic
|
||||||
|
|
||||||
# 检查配置文件
|
# 检查配置文件
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
import quart
|
import quart
|
||||||
|
|
||||||
from .....core import taskmgr
|
from .....core import taskmgr
|
||||||
from .. import group
|
from .. import group
|
||||||
|
from langbot_plugin.runtime.plugin.mgr import PluginInstallSource
|
||||||
|
|
||||||
|
|
||||||
@group.group_class('plugins', '/api/v1/plugins')
|
@group.group_class('plugins', '/api/v1/plugins')
|
||||||
@@ -12,35 +13,22 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('', methods=['GET'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
plugins = self.ap.plugin_mgr.plugins()
|
plugins = await self.ap.plugin_connector.list_plugins()
|
||||||
|
|
||||||
plugins_data = [plugin.model_dump() for plugin in plugins]
|
return self.success(data={'plugins': plugins})
|
||||||
|
|
||||||
return self.success(data={'plugins': plugins_data})
|
|
||||||
|
|
||||||
@self.route(
|
@self.route(
|
||||||
'/<author>/<plugin_name>/toggle',
|
'/<author>/<plugin_name>/upgrade',
|
||||||
methods=['PUT'],
|
|
||||||
auth_type=group.AuthType.USER_TOKEN,
|
|
||||||
)
|
|
||||||
async def _(author: str, plugin_name: str) -> str:
|
|
||||||
data = await quart.request.json
|
|
||||||
target_enabled = data.get('target_enabled')
|
|
||||||
await self.ap.plugin_mgr.update_plugin_switch(plugin_name, target_enabled)
|
|
||||||
return self.success()
|
|
||||||
|
|
||||||
@self.route(
|
|
||||||
'/<author>/<plugin_name>/update',
|
|
||||||
methods=['POST'],
|
methods=['POST'],
|
||||||
auth_type=group.AuthType.USER_TOKEN,
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
)
|
)
|
||||||
async def _(author: str, plugin_name: str) -> str:
|
async def _(author: str, plugin_name: str) -> str:
|
||||||
ctx = taskmgr.TaskContext.new()
|
ctx = taskmgr.TaskContext.new()
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self.ap.plugin_mgr.update_plugin(plugin_name, task_context=ctx),
|
self.ap.plugin_connector.upgrade_plugin(author, plugin_name, task_context=ctx),
|
||||||
kind='plugin-operation',
|
kind='plugin-operation',
|
||||||
name=f'plugin-update-{plugin_name}',
|
name=f'plugin-upgrade-{plugin_name}',
|
||||||
label=f'更新插件 {plugin_name}',
|
label=f'Upgrading plugin {plugin_name}',
|
||||||
context=ctx,
|
context=ctx,
|
||||||
)
|
)
|
||||||
return self.success(data={'task_id': wrapper.id})
|
return self.success(data={'task_id': wrapper.id})
|
||||||
@@ -52,17 +40,17 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
)
|
)
|
||||||
async def _(author: str, plugin_name: str) -> str:
|
async def _(author: str, plugin_name: str) -> str:
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
|
||||||
if plugin is None:
|
if plugin is None:
|
||||||
return self.http_status(404, -1, 'plugin not found')
|
return self.http_status(404, -1, 'plugin not found')
|
||||||
return self.success(data={'plugin': plugin.model_dump()})
|
return self.success(data={'plugin': plugin})
|
||||||
elif quart.request.method == 'DELETE':
|
elif quart.request.method == 'DELETE':
|
||||||
ctx = taskmgr.TaskContext.new()
|
ctx = taskmgr.TaskContext.new()
|
||||||
wrapper = self.ap.task_mgr.create_user_task(
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
self.ap.plugin_mgr.uninstall_plugin(plugin_name, task_context=ctx),
|
self.ap.plugin_connector.delete_plugin(author, plugin_name, task_context=ctx),
|
||||||
kind='plugin-operation',
|
kind='plugin-operation',
|
||||||
name=f'plugin-remove-{plugin_name}',
|
name=f'plugin-remove-{plugin_name}',
|
||||||
label=f'删除插件 {plugin_name}',
|
label=f'Removing plugin {plugin_name}',
|
||||||
context=ctx,
|
context=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -74,24 +62,19 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
auth_type=group.AuthType.USER_TOKEN,
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
)
|
)
|
||||||
async def _(author: str, plugin_name: str) -> quart.Response:
|
async def _(author: str, plugin_name: str) -> quart.Response:
|
||||||
plugin = self.ap.plugin_mgr.get_plugin(author, plugin_name)
|
plugin = await self.ap.plugin_connector.get_plugin_info(author, plugin_name)
|
||||||
if plugin is None:
|
if plugin is None:
|
||||||
return self.http_status(404, -1, 'plugin not found')
|
return self.http_status(404, -1, 'plugin not found')
|
||||||
|
|
||||||
if quart.request.method == 'GET':
|
if quart.request.method == 'GET':
|
||||||
return self.success(data={'config': plugin.plugin_config})
|
return self.success(data={'config': plugin['plugin_config']})
|
||||||
elif quart.request.method == 'PUT':
|
elif quart.request.method == 'PUT':
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
|
|
||||||
await self.ap.plugin_mgr.set_plugin_config(plugin, data)
|
await self.ap.plugin_connector.set_plugin_config(author, plugin_name, data)
|
||||||
|
|
||||||
return self.success(data={})
|
return self.success(data={})
|
||||||
|
|
||||||
@self.route('/reorder', methods=['PUT'], auth_type=group.AuthType.USER_TOKEN)
|
|
||||||
async def _() -> str:
|
|
||||||
data = await quart.request.json
|
|
||||||
await self.ap.plugin_mgr.reorder_plugins(data.get('plugins'))
|
|
||||||
return self.success()
|
|
||||||
|
|
||||||
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/install/github', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
@@ -102,7 +85,47 @@ class PluginsRouterGroup(group.RouterGroup):
|
|||||||
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
|
self.ap.plugin_mgr.install_plugin(data['source'], task_context=ctx),
|
||||||
kind='plugin-operation',
|
kind='plugin-operation',
|
||||||
name='plugin-install-github',
|
name='plugin-install-github',
|
||||||
label=f'安装插件 ...{short_source_str}',
|
label=f'Installing plugin from github ...{short_source_str}',
|
||||||
|
context=ctx,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.success(data={'task_id': wrapper.id})
|
||||||
|
|
||||||
|
@self.route('/install/marketplace', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
|
async def _() -> str:
|
||||||
|
data = await quart.request.json
|
||||||
|
|
||||||
|
ctx = taskmgr.TaskContext.new()
|
||||||
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
|
self.ap.plugin_connector.install_plugin(PluginInstallSource.MARKETPLACE, data, task_context=ctx),
|
||||||
|
kind='plugin-operation',
|
||||||
|
name='plugin-install-marketplace',
|
||||||
|
label=f'Installing plugin from marketplace ...{data}',
|
||||||
|
context=ctx,
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.success(data={'task_id': wrapper.id})
|
||||||
|
|
||||||
|
@self.route('/install/local', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
|
async def _() -> str:
|
||||||
|
file = (await quart.request.files).get('file')
|
||||||
|
if file is None:
|
||||||
|
return self.http_status(400, -1, 'file is required')
|
||||||
|
|
||||||
|
file_bytes = file.read()
|
||||||
|
|
||||||
|
file_base64 = base64.b64encode(file_bytes).decode('utf-8')
|
||||||
|
|
||||||
|
data = {
|
||||||
|
'plugin_file': file_base64,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx = taskmgr.TaskContext.new()
|
||||||
|
wrapper = self.ap.task_mgr.create_user_task(
|
||||||
|
self.ap.plugin_connector.install_plugin(PluginInstallSource.LOCAL, data, task_context=ctx),
|
||||||
|
kind='plugin-operation',
|
||||||
|
name='plugin-install-local',
|
||||||
|
label=f'Installing plugin from local ...{file.filename}',
|
||||||
context=ctx,
|
context=ctx,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,11 @@ class SystemRouterGroup(group.RouterGroup):
|
|||||||
'version': constants.semantic_version,
|
'version': constants.semantic_version,
|
||||||
'debug': constants.debug_mode,
|
'debug': constants.debug_mode,
|
||||||
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
|
'enabled_platform_count': len(self.ap.platform_mgr.get_running_adapters()),
|
||||||
|
'cloud_service_url': (
|
||||||
|
self.ap.instance_config.data['plugin']['cloud_service_url']
|
||||||
|
if 'cloud_service_url' in self.ap.instance_config.data['plugin']
|
||||||
|
else 'https://space.langbot.app'
|
||||||
|
),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -35,16 +40,7 @@ class SystemRouterGroup(group.RouterGroup):
|
|||||||
|
|
||||||
return self.success(data=task.to_dict())
|
return self.success(data=task.to_dict())
|
||||||
|
|
||||||
@self.route('/reload', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
@self.route('/debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
async def _() -> str:
|
|
||||||
json_data = await quart.request.json
|
|
||||||
|
|
||||||
scope = json_data.get('scope')
|
|
||||||
|
|
||||||
await self.ap.reload(scope=scope)
|
|
||||||
return self.success()
|
|
||||||
|
|
||||||
@self.route('/_debug/exec', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
|
||||||
async def _() -> str:
|
async def _() -> str:
|
||||||
if not constants.debug_mode:
|
if not constants.debug_mode:
|
||||||
return self.http_status(403, 403, 'Forbidden')
|
return self.http_status(403, 403, 'Forbidden')
|
||||||
@@ -54,3 +50,39 @@ class SystemRouterGroup(group.RouterGroup):
|
|||||||
ap = self.ap
|
ap = self.ap
|
||||||
|
|
||||||
return self.success(data=exec(py_code, {'ap': ap}))
|
return self.success(data=exec(py_code, {'ap': ap}))
|
||||||
|
|
||||||
|
@self.route('/debug/tools/call', methods=['POST'], auth_type=group.AuthType.USER_TOKEN)
|
||||||
|
async def _() -> str:
|
||||||
|
if not constants.debug_mode:
|
||||||
|
return self.http_status(403, 403, 'Forbidden')
|
||||||
|
|
||||||
|
data = await quart.request.json
|
||||||
|
|
||||||
|
return self.success(
|
||||||
|
data=await self.ap.tool_mgr.execute_func_call(data['tool_name'], data['tool_parameters'])
|
||||||
|
)
|
||||||
|
|
||||||
|
@self.route(
|
||||||
|
'/debug/plugin/action',
|
||||||
|
methods=['POST'],
|
||||||
|
auth_type=group.AuthType.USER_TOKEN,
|
||||||
|
)
|
||||||
|
async def _() -> str:
|
||||||
|
if not constants.debug_mode:
|
||||||
|
return self.http_status(403, 403, 'Forbidden')
|
||||||
|
|
||||||
|
data = await quart.request.json
|
||||||
|
|
||||||
|
class AnoymousAction:
|
||||||
|
value = 'anonymous_action'
|
||||||
|
|
||||||
|
def __init__(self, value: str):
|
||||||
|
self.value = value
|
||||||
|
|
||||||
|
resp = await self.ap.plugin_connector.handler.call_action(
|
||||||
|
AnoymousAction(data['action']),
|
||||||
|
data['data'],
|
||||||
|
timeout=data.get('timeout', 10),
|
||||||
|
)
|
||||||
|
|
||||||
|
return self.success(data=resp)
|
||||||
|
|||||||
@@ -17,15 +17,19 @@ class BotService:
|
|||||||
def __init__(self, ap: app.Application) -> None:
|
def __init__(self, ap: app.Application) -> None:
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
|
||||||
async def get_bots(self) -> list[dict]:
|
async def get_bots(self, include_secret: bool = True) -> list[dict]:
|
||||||
"""获取所有机器人"""
|
"""获取所有机器人"""
|
||||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_bot.Bot))
|
||||||
|
|
||||||
bots = result.all()
|
bots = result.all()
|
||||||
|
|
||||||
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot) for bot in bots]
|
masked_columns = []
|
||||||
|
if not include_secret:
|
||||||
|
masked_columns = ['adapter_config']
|
||||||
|
|
||||||
async def get_bot(self, bot_uuid: str) -> dict | None:
|
return [self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns) for bot in bots]
|
||||||
|
|
||||||
|
async def get_bot(self, bot_uuid: str, include_secret: bool = True) -> dict | None:
|
||||||
"""获取机器人"""
|
"""获取机器人"""
|
||||||
result = await self.ap.persistence_mgr.execute_async(
|
result = await self.ap.persistence_mgr.execute_async(
|
||||||
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
sqlalchemy.select(persistence_bot.Bot).where(persistence_bot.Bot.uuid == bot_uuid)
|
||||||
@@ -36,7 +40,27 @@ class BotService:
|
|||||||
if bot is None:
|
if bot is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot)
|
masked_columns = []
|
||||||
|
if not include_secret:
|
||||||
|
masked_columns = ['adapter_config']
|
||||||
|
|
||||||
|
return self.ap.persistence_mgr.serialize_model(persistence_bot.Bot, bot, masked_columns)
|
||||||
|
|
||||||
|
async def get_runtime_bot_info(self, bot_uuid: str, include_secret: bool = True) -> dict:
|
||||||
|
"""获取机器人运行时信息"""
|
||||||
|
persistence_bot = await self.get_bot(bot_uuid, include_secret)
|
||||||
|
if persistence_bot is None:
|
||||||
|
raise Exception('Bot not found')
|
||||||
|
|
||||||
|
adapter_runtime_values = {}
|
||||||
|
|
||||||
|
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||||
|
if runtime_bot is not None:
|
||||||
|
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||||
|
|
||||||
|
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||||
|
|
||||||
|
return persistence_bot
|
||||||
|
|
||||||
async def create_bot(self, bot_data: dict) -> str:
|
async def create_bot(self, bot_data: dict) -> str:
|
||||||
"""创建机器人"""
|
"""创建机器人"""
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ from ....core import app
|
|||||||
from ....entity.persistence import model as persistence_model
|
from ....entity.persistence import model as persistence_model
|
||||||
from ....entity.persistence import pipeline as persistence_pipeline
|
from ....entity.persistence import pipeline as persistence_pipeline
|
||||||
from ....provider.modelmgr import requester as model_requester
|
from ....provider.modelmgr import requester as model_requester
|
||||||
from ....provider import entities as llm_entities
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||||
|
|
||||||
|
|
||||||
class ModelsService:
|
class ModelsService:
|
||||||
@@ -16,11 +16,19 @@ class ModelsService:
|
|||||||
def __init__(self, ap: app.Application) -> None:
|
def __init__(self, ap: app.Application) -> None:
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
|
|
||||||
async def get_llm_models(self) -> list[dict]:
|
async def get_llm_models(self, include_secret: bool = True) -> list[dict]:
|
||||||
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_model.LLMModel))
|
||||||
|
|
||||||
models = result.all()
|
models = result.all()
|
||||||
return [self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model) for model in models]
|
|
||||||
|
masked_columns = []
|
||||||
|
if not include_secret:
|
||||||
|
masked_columns = ['api_keys']
|
||||||
|
|
||||||
|
return [
|
||||||
|
self.ap.persistence_mgr.serialize_model(persistence_model.LLMModel, model, masked_columns)
|
||||||
|
for model in models
|
||||||
|
]
|
||||||
|
|
||||||
async def create_llm_model(self, model_data: dict) -> str:
|
async def create_llm_model(self, model_data: dict) -> str:
|
||||||
model_data['uuid'] = str(uuid.uuid4())
|
model_data['uuid'] = str(uuid.uuid4())
|
||||||
@@ -99,7 +107,7 @@ class ModelsService:
|
|||||||
await runtime_llm_model.requester.invoke_llm(
|
await runtime_llm_model.requester.invoke_llm(
|
||||||
query=None,
|
query=None,
|
||||||
model=runtime_llm_model,
|
model=runtime_llm_model,
|
||||||
messages=[llm_entities.Message(role='user', content='Hello, world!')],
|
messages=[provider_message.Message(role='user', content='Hello, world!')],
|
||||||
funcs=[],
|
funcs=[],
|
||||||
extra_args={},
|
extra_args={},
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -2,9 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app
|
||||||
from . import entities, operator, errors
|
from . import operator
|
||||||
from ..utils import importutil
|
from ..utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
# 引入所有算子以便注册
|
# 引入所有算子以便注册
|
||||||
from . import operators
|
from . import operators
|
||||||
@@ -13,13 +16,11 @@ importutil.import_modules_in_pkg(operators)
|
|||||||
|
|
||||||
|
|
||||||
class CommandManager:
|
class CommandManager:
|
||||||
"""命令管理器"""
|
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|
||||||
cmd_list: list[operator.CommandOperator]
|
cmd_list: list[operator.CommandOperator]
|
||||||
"""
|
"""
|
||||||
运行时命令列表,扁平存储,各个对象包含对应的子节点引用
|
Runtime command list, flat storage, each object contains a reference to the corresponding child node
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, ap: app.Application):
|
def __init__(self, ap: app.Application):
|
||||||
@@ -55,43 +56,28 @@ class CommandManager:
|
|||||||
|
|
||||||
async def _execute(
|
async def _execute(
|
||||||
self,
|
self,
|
||||||
context: entities.ExecuteContext,
|
context: command_context.ExecuteContext,
|
||||||
operator_list: list[operator.CommandOperator],
|
operator_list: list[operator.CommandOperator],
|
||||||
operator: operator.CommandOperator = None,
|
operator: operator.CommandOperator = None,
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
"""执行命令"""
|
"""执行命令"""
|
||||||
|
|
||||||
found = False
|
command_list = await self.ap.plugin_connector.list_commands()
|
||||||
if len(context.crt_params) > 0: # 查找下一个参数是否对应此节点的某个子节点名
|
|
||||||
for oper in operator_list:
|
|
||||||
if (context.crt_params[0] == oper.name or context.crt_params[0] in oper.alias) and (
|
|
||||||
oper.parent_class is None or oper.parent_class == operator.__class__
|
|
||||||
):
|
|
||||||
found = True
|
|
||||||
|
|
||||||
context.crt_command = context.crt_params[0]
|
for command in command_list:
|
||||||
context.crt_params = context.crt_params[1:]
|
if command.metadata.name == context.command:
|
||||||
|
async for ret in self.ap.plugin_connector.execute_command(context):
|
||||||
async for ret in self._execute(context, oper.children, oper):
|
yield ret
|
||||||
yield ret
|
break
|
||||||
break
|
else:
|
||||||
|
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(context.command))
|
||||||
if not found: # 如果下一个参数未在此节点的子节点中找到,则执行此节点或者报错
|
|
||||||
if operator is None:
|
|
||||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(context.crt_params[0]))
|
|
||||||
else:
|
|
||||||
if operator.lowest_privilege > context.privilege:
|
|
||||||
yield entities.CommandReturn(error=errors.CommandPrivilegeError(operator.name))
|
|
||||||
else:
|
|
||||||
async for ret in operator.execute(context):
|
|
||||||
yield ret
|
|
||||||
|
|
||||||
async def execute(
|
async def execute(
|
||||||
self,
|
self,
|
||||||
command_text: str,
|
command_text: str,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
session: core_entities.Session,
|
session: provider_session.Session,
|
||||||
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
"""执行命令"""
|
"""执行命令"""
|
||||||
|
|
||||||
privilege = 1
|
privilege = 1
|
||||||
@@ -99,8 +85,8 @@ class CommandManager:
|
|||||||
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
if f'{query.launcher_type.value}_{query.launcher_id}' in self.ap.instance_config.data['admins']:
|
||||||
privilege = 2
|
privilege = 2
|
||||||
|
|
||||||
ctx = entities.ExecuteContext(
|
ctx = command_context.ExecuteContext(
|
||||||
query=query,
|
query_id=query.query_id,
|
||||||
session=session,
|
session=session,
|
||||||
command_text=command_text,
|
command_text=command_text,
|
||||||
command='',
|
command='',
|
||||||
@@ -110,5 +96,9 @@ class CommandManager:
|
|||||||
privilege=privilege,
|
privilege=privilege,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
ctx.command = ctx.params[0]
|
||||||
|
|
||||||
|
ctx.shift()
|
||||||
|
|
||||||
async for ret in self._execute(ctx, self.cmd_list):
|
async for ret in self._execute(ctx, self.cmd_list):
|
||||||
yield ret
|
yield ret
|
||||||
|
|||||||
@@ -1,74 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
|
||||||
|
|
||||||
from ..core import entities as core_entities
|
|
||||||
from . import errors
|
|
||||||
from ..platform.types import message as platform_message
|
|
||||||
|
|
||||||
|
|
||||||
class CommandReturn(pydantic.BaseModel):
|
|
||||||
"""命令返回值"""
|
|
||||||
|
|
||||||
text: typing.Optional[str] = None
|
|
||||||
"""文本
|
|
||||||
"""
|
|
||||||
|
|
||||||
image: typing.Optional[platform_message.Image] = None
|
|
||||||
"""弃用"""
|
|
||||||
|
|
||||||
image_url: typing.Optional[str] = None
|
|
||||||
"""图片链接
|
|
||||||
"""
|
|
||||||
|
|
||||||
error: typing.Optional[errors.CommandError] = None
|
|
||||||
"""错误
|
|
||||||
"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class ExecuteContext(pydantic.BaseModel):
|
|
||||||
"""单次命令执行上下文"""
|
|
||||||
|
|
||||||
query: core_entities.Query
|
|
||||||
"""本次消息的请求对象"""
|
|
||||||
|
|
||||||
session: core_entities.Session
|
|
||||||
"""本次消息所属的会话对象"""
|
|
||||||
|
|
||||||
command_text: str
|
|
||||||
"""命令完整文本"""
|
|
||||||
|
|
||||||
command: str
|
|
||||||
"""命令名称"""
|
|
||||||
|
|
||||||
crt_command: str
|
|
||||||
"""当前命令
|
|
||||||
|
|
||||||
多级命令中crt_command为当前命令,command为根命令。
|
|
||||||
例如:!plugin on Webwlkr
|
|
||||||
处理到plugin时,command为plugin,crt_command为plugin
|
|
||||||
处理到on时,command为plugin,crt_command为on
|
|
||||||
"""
|
|
||||||
|
|
||||||
params: list[str]
|
|
||||||
"""命令参数
|
|
||||||
|
|
||||||
整个命令以空格分割后的参数列表
|
|
||||||
"""
|
|
||||||
|
|
||||||
crt_params: list[str]
|
|
||||||
"""当前命令参数
|
|
||||||
|
|
||||||
多级命令中crt_params为当前命令参数,params为根命令参数。
|
|
||||||
例如:!plugin on Webwlkr
|
|
||||||
处理到plugin时,params为['on', 'Webwlkr'],crt_params为['on', 'Webwlkr']
|
|
||||||
处理到on时,params为['on', 'Webwlkr'],crt_params为['Webwlkr']
|
|
||||||
"""
|
|
||||||
|
|
||||||
privilege: int
|
|
||||||
"""发起人权限"""
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
class CommandError(Exception):
|
|
||||||
def __init__(self, message: str = None):
|
|
||||||
self.message = message
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return self.message
|
|
||||||
|
|
||||||
|
|
||||||
class CommandNotFoundError(CommandError):
|
|
||||||
def __init__(self, message: str = None):
|
|
||||||
super().__init__('未知命令: ' + message)
|
|
||||||
|
|
||||||
|
|
||||||
class CommandPrivilegeError(CommandError):
|
|
||||||
def __init__(self, message: str = None):
|
|
||||||
super().__init__('权限不足: ' + message)
|
|
||||||
|
|
||||||
|
|
||||||
class ParamNotEnoughError(CommandError):
|
|
||||||
def __init__(self, message: str = None):
|
|
||||||
super().__init__('参数不足: ' + message)
|
|
||||||
|
|
||||||
|
|
||||||
class CommandOperationError(CommandError):
|
|
||||||
def __init__(self, message: str = None):
|
|
||||||
super().__init__('操作失败: ' + message)
|
|
||||||
@@ -4,7 +4,7 @@ import typing
|
|||||||
import abc
|
import abc
|
||||||
|
|
||||||
from ..core import app
|
from ..core import app
|
||||||
from . import entities
|
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||||
|
|
||||||
|
|
||||||
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
preregistered_operators: list[typing.Type[CommandOperator]] = []
|
||||||
@@ -95,16 +95,18 @@ class CommandOperator(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
"""实现此方法以执行命令
|
"""实现此方法以执行命令
|
||||||
|
|
||||||
支持多次yield以返回多个结果。
|
支持多次yield以返回多个结果。
|
||||||
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
|
例如:一个安装插件的命令,可能会有下载、解压、安装等多个步骤,每个步骤都可以返回一个结果。
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
context (entities.ExecuteContext): 命令执行上下文
|
context (command_context.ExecuteContext): 命令执行上下文
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
entities.CommandReturn: 命令返回封装
|
command_context.CommandReturn: 命令返回封装
|
||||||
"""
|
"""
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
|
@operator.operator_class(name='cmd', help='显示命令列表', usage='!cmd\n!cmd <命令名称>')
|
||||||
class CmdOperator(operator.CommandOperator):
|
class CmdOperator(operator.CommandOperator):
|
||||||
"""命令列表"""
|
"""命令列表"""
|
||||||
|
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
"""执行"""
|
"""执行"""
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
reply_str = '当前所有命令: \n\n'
|
reply_str = '当前所有命令: \n\n'
|
||||||
@@ -20,7 +23,7 @@ class CmdOperator(operator.CommandOperator):
|
|||||||
|
|
||||||
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
|
reply_str += '\n使用 !cmd <命令名称> 查看命令的详细帮助'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str.strip())
|
yield command_context.CommandReturn(text=reply_str.strip())
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cmd_name = context.crt_params[0]
|
cmd_name = context.crt_params[0]
|
||||||
@@ -33,9 +36,9 @@ class CmdOperator(operator.CommandOperator):
|
|||||||
break
|
break
|
||||||
|
|
||||||
if cmd is None:
|
if cmd is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandNotFoundError(cmd_name))
|
yield command_context.CommandReturn(error=command_errors.CommandNotFoundError(cmd_name))
|
||||||
else:
|
else:
|
||||||
reply_str = f'{cmd.name}: {cmd.help}\n\n'
|
reply_str = f'{cmd.name}: {cmd.help}\n\n'
|
||||||
reply_str += f'使用方法: \n{cmd.usage}'
|
reply_str += f'使用方法: \n{cmd.usage}'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str.strip())
|
yield command_context.CommandReturn(text=reply_str.strip())
|
||||||
|
|||||||
@@ -2,23 +2,26 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
|
@operator.operator_class(name='del', help='删除当前会话的历史记录', usage='!del <序号>\n!del all')
|
||||||
class DelOperator(operator.CommandOperator):
|
class DelOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if context.session.conversations:
|
if context.session.conversations:
|
||||||
delete_index = 0
|
delete_index = 0
|
||||||
if len(context.crt_params) > 0:
|
if len(context.crt_params) > 0:
|
||||||
try:
|
try:
|
||||||
delete_index = int(context.crt_params[0])
|
delete_index = int(context.crt_params[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引必须是整数'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引必须是整数'))
|
||||||
return
|
return
|
||||||
|
|
||||||
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
if delete_index < 0 or delete_index >= len(context.session.conversations):
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('索引超出范围'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('索引超出范围'))
|
||||||
return
|
return
|
||||||
|
|
||||||
# 倒序
|
# 倒序
|
||||||
@@ -29,15 +32,17 @@ class DelOperator(operator.CommandOperator):
|
|||||||
|
|
||||||
del context.session.conversations[to_delete_index]
|
del context.session.conversations[to_delete_index]
|
||||||
|
|
||||||
yield entities.CommandReturn(text=f'已删除对话: {delete_index}')
|
yield command_context.CommandReturn(text=f'已删除对话: {delete_index}')
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
|
@operator.operator_class(name='all', help='删除此会话的所有历史记录', parent_class=DelOperator)
|
||||||
class DelAllOperator(operator.CommandOperator):
|
class DelAllOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
context.session.conversations = []
|
context.session.conversations = []
|
||||||
context.session.using_conversation = None
|
context.session.using_conversation = None
|
||||||
|
|
||||||
yield entities.CommandReturn(text='已删除所有对话')
|
yield command_context.CommandReturn(text='已删除所有对话')
|
||||||
|
|||||||
@@ -1,19 +1,20 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
from .. import operator, entities
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
|
@operator.operator_class(name='func', help='查看所有已注册的内容函数', usage='!func')
|
||||||
class FuncOperator(operator.CommandOperator):
|
class FuncOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
reply_str = '当前已启用的内容函数: \n\n'
|
reply_str = '当前已启用的内容函数: \n\n'
|
||||||
|
|
||||||
index = 1
|
index = 1
|
||||||
|
|
||||||
all_functions = await self.ap.tool_mgr.get_all_functions(
|
all_functions = await self.ap.tool_mgr.get_all_tools()
|
||||||
plugin_enabled=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
for func in all_functions:
|
for func in all_functions:
|
||||||
reply_str += '{}. {}:\n{}\n\n'.format(
|
reply_str += '{}. {}:\n{}\n\n'.format(
|
||||||
@@ -23,4 +24,4 @@ class FuncOperator(operator.CommandOperator):
|
|||||||
)
|
)
|
||||||
index += 1
|
index += 1
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str)
|
yield command_context.CommandReturn(text=reply_str)
|
||||||
|
|||||||
@@ -2,14 +2,17 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
|
@operator.operator_class(name='help', help='显示帮助', usage='!help\n!help <命令名称>')
|
||||||
class HelpOperator(operator.CommandOperator):
|
class HelpOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
help = 'LangBot - 大语言模型原生即时通信机器人平台\n链接:https://langbot.app'
|
||||||
|
|
||||||
help += '\n发送命令 !cmd 可查看命令列表'
|
help += '\n发送命令 !cmd 可查看命令列表'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=help)
|
yield command_context.CommandReturn(text=help)
|
||||||
|
|||||||
@@ -3,26 +3,31 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
|
@operator.operator_class(name='last', help='切换到前一个对话', usage='!last')
|
||||||
class LastOperator(operator.CommandOperator):
|
class LastOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if context.session.conversations:
|
if context.session.conversations:
|
||||||
# 找到当前会话的上一个会话
|
# 找到当前会话的上一个会话
|
||||||
for index in range(len(context.session.conversations) - 1, -1, -1):
|
for index in range(len(context.session.conversations) - 1, -1, -1):
|
||||||
if context.session.conversations[index] == context.session.using_conversation:
|
if context.session.conversations[index] == context.session.using_conversation:
|
||||||
if index == 0:
|
if index == 0:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是第一个对话了'))
|
yield command_context.CommandReturn(
|
||||||
|
error=command_errors.CommandOperationError('已经是第一个对话了')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
context.session.using_conversation = context.session.conversations[index - 1]
|
context.session.using_conversation = context.session.conversations[index - 1]
|
||||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
yield entities.CommandReturn(
|
yield command_context.CommandReturn(
|
||||||
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
|
text=f'已切换到上一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].readable_str()}'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||||
|
|||||||
@@ -2,19 +2,22 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
|
@operator.operator_class(name='list', help='列出此会话中的所有历史对话', usage='!list\n!list <页码>')
|
||||||
class ListOperator(operator.CommandOperator):
|
class ListOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
page = 0
|
page = 0
|
||||||
|
|
||||||
if len(context.crt_params) > 0:
|
if len(context.crt_params) > 0:
|
||||||
try:
|
try:
|
||||||
page = int(context.crt_params[0] - 1)
|
page = int(context.crt_params[0] - 1)
|
||||||
except Exception:
|
except Exception:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('页码应为整数'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('页码应为整数'))
|
||||||
return
|
return
|
||||||
|
|
||||||
record_per_page = 10
|
record_per_page = 10
|
||||||
@@ -45,4 +48,4 @@ class ListOperator(operator.CommandOperator):
|
|||||||
else:
|
else:
|
||||||
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
|
content += f'\n当前会话: {using_conv_index} {context.session.using_conversation.create_time.strftime("%Y-%m-%d %H:%M:%S")}: {context.session.using_conversation.messages[0].readable_str() if len(context.session.using_conversation.messages) > 0 else "无内容"}'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
yield command_context.CommandReturn(text=f'第 {page + 1} 页 (时间倒序):\n{content}')
|
||||||
|
|||||||
@@ -2,26 +2,31 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
|
@operator.operator_class(name='next', help='切换到后一个对话', usage='!next')
|
||||||
class NextOperator(operator.CommandOperator):
|
class NextOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if context.session.conversations:
|
if context.session.conversations:
|
||||||
# 找到当前会话的下一个会话
|
# 找到当前会话的下一个会话
|
||||||
for index in range(len(context.session.conversations)):
|
for index in range(len(context.session.conversations)):
|
||||||
if context.session.conversations[index] == context.session.using_conversation:
|
if context.session.conversations[index] == context.session.using_conversation:
|
||||||
if index == len(context.session.conversations) - 1:
|
if index == len(context.session.conversations) - 1:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('已经是最后一个对话了'))
|
yield command_context.CommandReturn(
|
||||||
|
error=command_errors.CommandOperationError('已经是最后一个对话了')
|
||||||
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
context.session.using_conversation = context.session.conversations[index + 1]
|
context.session.using_conversation = context.session.conversations[index + 1]
|
||||||
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
time_str = context.session.using_conversation.create_time.strftime('%Y-%m-%d %H:%M:%S')
|
||||||
|
|
||||||
yield entities.CommandReturn(
|
yield command_context.CommandReturn(
|
||||||
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
|
text=f'已切换到后一个对话: {index} {time_str}: {context.session.using_conversation.messages[0].content}'
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(
|
@operator.operator_class(
|
||||||
@@ -11,7 +12,9 @@ from .. import operator, entities, errors
|
|||||||
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
|
usage='!plugin\n!plugin get <插件仓库地址>\n!plugin update\n!plugin del <插件名>\n!plugin on <插件名>\n!plugin off <插件名>',
|
||||||
)
|
)
|
||||||
class PluginOperator(operator.CommandOperator):
|
class PluginOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
plugin_list = self.ap.plugin_mgr.plugins()
|
plugin_list = self.ap.plugin_mgr.plugins()
|
||||||
reply_str = '所有插件({}):\n'.format(len(plugin_list))
|
reply_str = '所有插件({}):\n'.format(len(plugin_list))
|
||||||
idx = 0
|
idx = 0
|
||||||
@@ -27,32 +30,36 @@ class PluginOperator(operator.CommandOperator):
|
|||||||
|
|
||||||
idx += 1
|
idx += 1
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str)
|
yield command_context.CommandReturn(text=reply_str)
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
|
@operator.operator_class(name='get', help='安装插件', privilege=2, parent_class=PluginOperator)
|
||||||
class PluginGetOperator(operator.CommandOperator):
|
class PluginGetOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件仓库地址'))
|
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件仓库地址'))
|
||||||
else:
|
else:
|
||||||
repo = context.crt_params[0]
|
repo = context.crt_params[0]
|
||||||
|
|
||||||
yield entities.CommandReturn(text='正在安装插件...')
|
yield command_context.CommandReturn(text='正在安装插件...')
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.ap.plugin_mgr.install_plugin(repo)
|
await self.ap.plugin_mgr.install_plugin(repo)
|
||||||
yield entities.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
yield command_context.CommandReturn(text='插件安装成功,请重启程序以加载插件')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件安装失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件安装失败: ' + str(e)))
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
|
@operator.operator_class(name='update', help='更新插件', privilege=2, parent_class=PluginOperator)
|
||||||
class PluginUpdateOperator(operator.CommandOperator):
|
class PluginUpdateOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
@@ -60,24 +67,26 @@ class PluginUpdateOperator(operator.CommandOperator):
|
|||||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||||
|
|
||||||
if plugin_container is not None:
|
if plugin_container is not None:
|
||||||
yield entities.CommandReturn(text='正在更新插件...')
|
yield command_context.CommandReturn(text='正在更新插件...')
|
||||||
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
await self.ap.plugin_mgr.update_plugin(plugin_name)
|
||||||
yield entities.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
yield command_context.CommandReturn(text='插件更新成功,请重启程序以加载插件')
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: 未找到插件'))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: 未找到插件'))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
|
@operator.operator_class(name='all', help='更新所有插件', privilege=2, parent_class=PluginUpdateOperator)
|
||||||
class PluginUpdateAllOperator(operator.CommandOperator):
|
class PluginUpdateAllOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
try:
|
try:
|
||||||
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
|
plugins = [p.plugin_name for p in self.ap.plugin_mgr.plugins()]
|
||||||
|
|
||||||
if plugins:
|
if plugins:
|
||||||
yield entities.CommandReturn(text='正在更新插件...')
|
yield command_context.CommandReturn(text='正在更新插件...')
|
||||||
updated = []
|
updated = []
|
||||||
try:
|
try:
|
||||||
for plugin_name in plugins:
|
for plugin_name in plugins:
|
||||||
@@ -85,20 +94,22 @@ class PluginUpdateAllOperator(operator.CommandOperator):
|
|||||||
updated.append(plugin_name)
|
updated.append(plugin_name)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||||
yield entities.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
yield command_context.CommandReturn(text='已更新插件: {}'.format(', '.join(updated)))
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(text='没有可更新的插件')
|
yield command_context.CommandReturn(text='没有可更新的插件')
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件更新失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件更新失败: ' + str(e)))
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
|
@operator.operator_class(name='del', help='删除插件', privilege=2, parent_class=PluginOperator)
|
||||||
class PluginDelOperator(operator.CommandOperator):
|
class PluginDelOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
@@ -106,51 +117,55 @@ class PluginDelOperator(operator.CommandOperator):
|
|||||||
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
plugin_container = self.ap.plugin_mgr.get_plugin_by_name(plugin_name)
|
||||||
|
|
||||||
if plugin_container is not None:
|
if plugin_container is not None:
|
||||||
yield entities.CommandReturn(text='正在删除插件...')
|
yield command_context.CommandReturn(text='正在删除插件...')
|
||||||
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
await self.ap.plugin_mgr.uninstall_plugin(plugin_name)
|
||||||
yield entities.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
yield command_context.CommandReturn(text='插件删除成功,请重启程序以加载插件')
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: 未找到插件'))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: 未找到插件'))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件删除失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件删除失败: ' + str(e)))
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
|
@operator.operator_class(name='on', help='启用插件', privilege=2, parent_class=PluginOperator)
|
||||||
class PluginEnableOperator(operator.CommandOperator):
|
class PluginEnableOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, True):
|
||||||
yield entities.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
yield command_context.CommandReturn(text='已启用插件: {}'.format(plugin_name))
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(
|
yield command_context.CommandReturn(
|
||||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
|
@operator.operator_class(name='off', help='禁用插件', privilege=2, parent_class=PluginOperator)
|
||||||
class PluginDisableOperator(operator.CommandOperator):
|
class PluginDisableOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
if len(context.crt_params) == 0:
|
if len(context.crt_params) == 0:
|
||||||
yield entities.CommandReturn(error=errors.ParamNotEnoughError('请提供插件名称'))
|
yield command_context.CommandReturn(error=command_errors.ParamNotEnoughError('请提供插件名称'))
|
||||||
else:
|
else:
|
||||||
plugin_name = context.crt_params[0]
|
plugin_name = context.crt_params[0]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
if await self.ap.plugin_mgr.update_plugin_switch(plugin_name, False):
|
||||||
yield entities.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
yield command_context.CommandReturn(text='已禁用插件: {}'.format(plugin_name))
|
||||||
else:
|
else:
|
||||||
yield entities.CommandReturn(
|
yield command_context.CommandReturn(
|
||||||
error=errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
error=command_errors.CommandError('插件状态修改失败: 未找到插件 {}'.format(plugin_name))
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
yield entities.CommandReturn(error=errors.CommandError('插件状态修改失败: ' + str(e)))
|
yield command_context.CommandReturn(error=command_errors.CommandError('插件状态修改失败: ' + str(e)))
|
||||||
|
|||||||
@@ -2,19 +2,22 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
|
@operator.operator_class(name='prompt', help='查看当前对话的前文', usage='!prompt')
|
||||||
class PromptOperator(operator.CommandOperator):
|
class PromptOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
"""执行"""
|
"""执行"""
|
||||||
if context.session.using_conversation is None:
|
if context.session.using_conversation is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandOperationError('当前没有对话'))
|
yield command_context.CommandReturn(error=command_errors.CommandOperationError('当前没有对话'))
|
||||||
else:
|
else:
|
||||||
reply_str = '当前对话所有内容:\n\n'
|
reply_str = '当前对话所有内容:\n\n'
|
||||||
|
|
||||||
for msg in context.session.using_conversation.messages:
|
for msg in context.session.using_conversation.messages:
|
||||||
reply_str += f'{msg.role}: {msg.content}\n'
|
reply_str += f'{msg.role}: {msg.content}\n'
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str)
|
yield command_context.CommandReturn(text=reply_str)
|
||||||
|
|||||||
@@ -2,15 +2,18 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities, errors
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context, errors as command_errors
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
|
@operator.operator_class(name='resend', help='重发当前会话的最后一条消息', usage='!resend')
|
||||||
class ResendOperator(operator.CommandOperator):
|
class ResendOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
# 回滚到最后一条用户message前
|
# 回滚到最后一条用户message前
|
||||||
if context.session.using_conversation is None:
|
if context.session.using_conversation is None:
|
||||||
yield entities.CommandReturn(error=errors.CommandError('当前没有对话'))
|
yield command_context.CommandReturn(error=command_errors.CommandError('当前没有对话'))
|
||||||
else:
|
else:
|
||||||
conv_msg = context.session.using_conversation.messages
|
conv_msg = context.session.using_conversation.messages
|
||||||
|
|
||||||
@@ -23,4 +26,4 @@ class ResendOperator(operator.CommandOperator):
|
|||||||
conv_msg.pop()
|
conv_msg.pop()
|
||||||
|
|
||||||
# 不重发了,提示用户已删除就行了
|
# 不重发了,提示用户已删除就行了
|
||||||
yield entities.CommandReturn(text='已删除最后一次请求记录')
|
yield command_context.CommandReturn(text='已删除最后一次请求记录')
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
|
@operator.operator_class(name='reset', help='重置当前会话', usage='!reset')
|
||||||
class ResetOperator(operator.CommandOperator):
|
class ResetOperator(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
"""执行"""
|
"""执行"""
|
||||||
context.session.using_conversation = None
|
context.session.using_conversation = None
|
||||||
|
|
||||||
yield entities.CommandReturn(text='已重置当前会话')
|
yield command_context.CommandReturn(text='已重置当前会话')
|
||||||
|
|||||||
@@ -1,11 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import typing
|
|
||||||
|
|
||||||
from .. import operator, entities
|
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='update', help='更新程序', usage='!update', privilege=2)
|
|
||||||
class UpdateCommand(operator.CommandOperator):
|
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
|
||||||
yield entities.CommandReturn(text='不再支持通过命令更新,请查看 LangBot 文档。')
|
|
||||||
@@ -2,12 +2,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from .. import operator, entities
|
from .. import operator
|
||||||
|
from langbot_plugin.api.entities.builtin.command import context as command_context
|
||||||
|
|
||||||
|
|
||||||
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
|
@operator.operator_class(name='version', help='显示版本信息', usage='!version')
|
||||||
class VersionCommand(operator.CommandOperator):
|
class VersionCommand(operator.CommandOperator):
|
||||||
async def execute(self, context: entities.ExecuteContext) -> typing.AsyncGenerator[entities.CommandReturn, None]:
|
async def execute(
|
||||||
|
self, context: command_context.ExecuteContext
|
||||||
|
) -> typing.AsyncGenerator[command_context.CommandReturn, None]:
|
||||||
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
|
reply_str = f'当前版本: \n{self.ap.ver_mgr.get_current_version()}'
|
||||||
|
|
||||||
try:
|
try:
|
||||||
@@ -16,4 +19,4 @@ class VersionCommand(operator.CommandOperator):
|
|||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
yield entities.CommandReturn(text=reply_str.strip())
|
yield command_context.CommandReturn(text=reply_str.strip())
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ from __future__ import annotations
|
|||||||
import logging
|
import logging
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import sys
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from ..platform import botmgr as im_mgr
|
from ..platform import botmgr as im_mgr
|
||||||
@@ -12,7 +11,7 @@ from ..provider.modelmgr import modelmgr as llm_model_mgr
|
|||||||
from ..provider.tools import toolmgr as llm_tool_mgr
|
from ..provider.tools import toolmgr as llm_tool_mgr
|
||||||
from ..config import manager as config_mgr
|
from ..config import manager as config_mgr
|
||||||
from ..command import cmdmgr
|
from ..command import cmdmgr
|
||||||
from ..plugin import manager as plugin_mgr
|
from ..plugin import connector as plugin_connector
|
||||||
from ..pipeline import pool
|
from ..pipeline import pool
|
||||||
from ..pipeline import controller, pipelinemgr
|
from ..pipeline import controller, pipelinemgr
|
||||||
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
from ..utils import version as version_mgr, proxy as proxy_mgr, announce as announce_mgr
|
||||||
@@ -75,7 +74,7 @@ class Application:
|
|||||||
|
|
||||||
# =========================
|
# =========================
|
||||||
|
|
||||||
plugin_mgr: plugin_mgr.PluginManager = None
|
plugin_connector: plugin_connector.PluginRuntimeConnector = None
|
||||||
|
|
||||||
query_pool: pool.QueryPool = None
|
query_pool: pool.QueryPool = None
|
||||||
|
|
||||||
@@ -117,7 +116,7 @@ class Application:
|
|||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
try:
|
try:
|
||||||
await self.plugin_mgr.initialize_plugins()
|
await self.plugin_connector.initialize_plugins()
|
||||||
|
|
||||||
# 后续可能会允许动态重启其他任务
|
# 后续可能会允许动态重启其他任务
|
||||||
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
# 故为了防止程序在非 Ctrl-C 情况下退出,这里创建一个不会结束的协程
|
||||||
@@ -157,6 +156,9 @@ class Application:
|
|||||||
self.logger.error(f'应用运行致命异常: {e}')
|
self.logger.error(f'应用运行致命异常: {e}')
|
||||||
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
self.logger.debug(f'Traceback: {traceback.format_exc()}')
|
||||||
|
|
||||||
|
def dispose(self):
|
||||||
|
self.plugin_connector.dispose()
|
||||||
|
|
||||||
async def print_web_access_info(self):
|
async def print_web_access_info(self):
|
||||||
"""打印访问 webui 的提示"""
|
"""打印访问 webui 的提示"""
|
||||||
|
|
||||||
@@ -183,59 +185,3 @@ class Application:
|
|||||||
""".strip()
|
""".strip()
|
||||||
for line in tips.split('\n'):
|
for line in tips.split('\n'):
|
||||||
self.logger.info(line)
|
self.logger.info(line)
|
||||||
|
|
||||||
async def reload(
|
|
||||||
self,
|
|
||||||
scope: core_entities.LifecycleControlScope,
|
|
||||||
):
|
|
||||||
match scope:
|
|
||||||
case core_entities.LifecycleControlScope.PLATFORM.value:
|
|
||||||
self.logger.info('执行热重载 scope=' + scope)
|
|
||||||
await self.platform_mgr.shutdown()
|
|
||||||
|
|
||||||
self.platform_mgr = im_mgr.PlatformManager(self)
|
|
||||||
|
|
||||||
await self.platform_mgr.initialize()
|
|
||||||
|
|
||||||
self.task_mgr.create_task(
|
|
||||||
self.platform_mgr.run(),
|
|
||||||
name='platform-manager',
|
|
||||||
scopes=[
|
|
||||||
core_entities.LifecycleControlScope.APPLICATION,
|
|
||||||
core_entities.LifecycleControlScope.PLATFORM,
|
|
||||||
],
|
|
||||||
)
|
|
||||||
case core_entities.LifecycleControlScope.PLUGIN.value:
|
|
||||||
self.logger.info('执行热重载 scope=' + scope)
|
|
||||||
await self.plugin_mgr.destroy_plugins()
|
|
||||||
|
|
||||||
# 删除 sys.module 中所有的 plugins/* 下的模块
|
|
||||||
for mod in list(sys.modules.keys()):
|
|
||||||
if mod.startswith('plugins.'):
|
|
||||||
del sys.modules[mod]
|
|
||||||
|
|
||||||
self.plugin_mgr = plugin_mgr.PluginManager(self)
|
|
||||||
await self.plugin_mgr.initialize()
|
|
||||||
|
|
||||||
await self.plugin_mgr.initialize_plugins()
|
|
||||||
|
|
||||||
await self.plugin_mgr.load_plugins()
|
|
||||||
await self.plugin_mgr.initialize_plugins()
|
|
||||||
case core_entities.LifecycleControlScope.PROVIDER.value:
|
|
||||||
self.logger.info('执行热重载 scope=' + scope)
|
|
||||||
|
|
||||||
await self.tool_mgr.shutdown()
|
|
||||||
|
|
||||||
llm_model_mgr_inst = llm_model_mgr.ModelManager(self)
|
|
||||||
await llm_model_mgr_inst.initialize()
|
|
||||||
self.model_mgr = llm_model_mgr_inst
|
|
||||||
|
|
||||||
llm_session_mgr_inst = llm_session_mgr.SessionManager(self)
|
|
||||||
await llm_session_mgr_inst.initialize()
|
|
||||||
self.sess_mgr = llm_session_mgr_inst
|
|
||||||
|
|
||||||
llm_tool_mgr_inst = llm_tool_mgr.ToolManager(self)
|
|
||||||
await llm_tool_mgr_inst.initialize()
|
|
||||||
self.tool_mgr = llm_tool_mgr_inst
|
|
||||||
case _:
|
|
||||||
pass
|
|
||||||
|
|||||||
@@ -51,8 +51,8 @@ async def main(loop: asyncio.AbstractEventLoop):
|
|||||||
import signal
|
import signal
|
||||||
|
|
||||||
def signal_handler(sig, frame):
|
def signal_handler(sig, frame):
|
||||||
|
app_inst.dispose()
|
||||||
print('[Signal] 程序退出.')
|
print('[Signal] 程序退出.')
|
||||||
# ap.shutdown()
|
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
signal.signal(signal.SIGINT, signal_handler)
|
signal.signal(signal.SIGINT, signal_handler)
|
||||||
|
|||||||
@@ -1,18 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import enum
|
import enum
|
||||||
import typing
|
|
||||||
import datetime
|
|
||||||
import asyncio
|
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
|
||||||
|
|
||||||
from ..provider import entities as llm_entities
|
|
||||||
from ..provider.modelmgr import requester
|
|
||||||
from ..provider.tools import entities as tools_entities
|
|
||||||
from ..platform import adapter as msadapter
|
|
||||||
from ..platform.types import message as platform_message
|
|
||||||
from ..platform.types import events as platform_events
|
|
||||||
|
|
||||||
|
|
||||||
class LifecycleControlScope(enum.Enum):
|
class LifecycleControlScope(enum.Enum):
|
||||||
@@ -20,157 +8,3 @@ class LifecycleControlScope(enum.Enum):
|
|||||||
PLATFORM = 'platform'
|
PLATFORM = 'platform'
|
||||||
PLUGIN = 'plugin'
|
PLUGIN = 'plugin'
|
||||||
PROVIDER = 'provider'
|
PROVIDER = 'provider'
|
||||||
|
|
||||||
|
|
||||||
class LauncherTypes(enum.Enum):
|
|
||||||
"""一个请求的发起者类型"""
|
|
||||||
|
|
||||||
PERSON = 'person'
|
|
||||||
"""私聊"""
|
|
||||||
|
|
||||||
GROUP = 'group'
|
|
||||||
"""群聊"""
|
|
||||||
|
|
||||||
|
|
||||||
class Query(pydantic.BaseModel):
|
|
||||||
"""一次请求的信息封装"""
|
|
||||||
|
|
||||||
query_id: int
|
|
||||||
"""请求ID,添加进请求池时生成"""
|
|
||||||
|
|
||||||
launcher_type: LauncherTypes
|
|
||||||
"""会话类型,platform处理阶段设置"""
|
|
||||||
|
|
||||||
launcher_id: typing.Union[int, str]
|
|
||||||
"""会话ID,platform处理阶段设置"""
|
|
||||||
|
|
||||||
sender_id: typing.Union[int, str]
|
|
||||||
"""发送者ID,platform处理阶段设置"""
|
|
||||||
|
|
||||||
message_event: platform_events.MessageEvent
|
|
||||||
"""事件,platform收到的原始事件"""
|
|
||||||
|
|
||||||
message_chain: platform_message.MessageChain
|
|
||||||
"""消息链,platform收到的原始消息链"""
|
|
||||||
|
|
||||||
bot_uuid: typing.Optional[str] = None
|
|
||||||
"""机器人UUID。"""
|
|
||||||
|
|
||||||
pipeline_uuid: typing.Optional[str] = None
|
|
||||||
"""流水线UUID。"""
|
|
||||||
|
|
||||||
pipeline_config: typing.Optional[dict[str, typing.Any]] = None
|
|
||||||
"""流水线配置,由 Pipeline 在运行开始时设置。"""
|
|
||||||
|
|
||||||
adapter: msadapter.MessagePlatformAdapter
|
|
||||||
"""消息平台适配器对象,单个app中可能启用了多个消息平台适配器,此对象表明发起此query的适配器"""
|
|
||||||
|
|
||||||
session: typing.Optional[Session] = None
|
|
||||||
"""会话对象,由前置处理器阶段设置"""
|
|
||||||
|
|
||||||
messages: typing.Optional[list[llm_entities.Message]] = []
|
|
||||||
"""历史消息列表,由前置处理器阶段设置"""
|
|
||||||
|
|
||||||
prompt: typing.Optional[llm_entities.Prompt] = None
|
|
||||||
"""情景预设内容,由前置处理器阶段设置"""
|
|
||||||
|
|
||||||
user_message: typing.Optional[llm_entities.Message] = None
|
|
||||||
"""此次请求的用户消息对象,由前置处理器阶段设置"""
|
|
||||||
|
|
||||||
variables: typing.Optional[dict[str, typing.Any]] = None
|
|
||||||
"""变量,由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
|
|
||||||
|
|
||||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
|
||||||
"""使用的对话模型,由前置处理器阶段设置"""
|
|
||||||
|
|
||||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
|
|
||||||
"""使用的函数,由前置处理器阶段设置"""
|
|
||||||
|
|
||||||
resp_messages: (
|
|
||||||
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
|
|
||||||
) = []
|
|
||||||
"""由Process阶段生成的回复消息对象列表"""
|
|
||||||
|
|
||||||
resp_message_chain: typing.Optional[list[platform_message.MessageChain]] = None
|
|
||||||
"""回复消息链,从resp_messages包装而得"""
|
|
||||||
|
|
||||||
# ======= 内部保留 =======
|
|
||||||
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
|
|
||||||
"""当前所处阶段"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
# ========== 插件可调用的 API(请求 API) ==========
|
|
||||||
|
|
||||||
def set_variable(self, key: str, value: typing.Any):
|
|
||||||
"""设置变量"""
|
|
||||||
if self.variables is None:
|
|
||||||
self.variables = {}
|
|
||||||
self.variables[key] = value
|
|
||||||
|
|
||||||
def get_variable(self, key: str) -> typing.Any:
|
|
||||||
"""获取变量"""
|
|
||||||
if self.variables is None:
|
|
||||||
return None
|
|
||||||
return self.variables.get(key)
|
|
||||||
|
|
||||||
def get_variables(self) -> dict[str, typing.Any]:
|
|
||||||
"""获取所有变量"""
|
|
||||||
if self.variables is None:
|
|
||||||
return {}
|
|
||||||
return self.variables
|
|
||||||
|
|
||||||
|
|
||||||
class Conversation(pydantic.BaseModel):
|
|
||||||
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation,但只有一个当前使用的 Conversation"""
|
|
||||||
|
|
||||||
prompt: llm_entities.Prompt
|
|
||||||
|
|
||||||
messages: list[llm_entities.Message]
|
|
||||||
|
|
||||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
|
||||||
|
|
||||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
|
||||||
|
|
||||||
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
|
|
||||||
|
|
||||||
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
|
|
||||||
|
|
||||||
pipeline_uuid: str
|
|
||||||
"""流水线UUID。"""
|
|
||||||
|
|
||||||
bot_uuid: str
|
|
||||||
"""机器人UUID。"""
|
|
||||||
|
|
||||||
uuid: typing.Optional[str] = None
|
|
||||||
"""该对话的 uuid,在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class Session(pydantic.BaseModel):
|
|
||||||
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
|
|
||||||
|
|
||||||
launcher_type: LauncherTypes
|
|
||||||
|
|
||||||
launcher_id: typing.Union[int, str]
|
|
||||||
|
|
||||||
sender_id: typing.Optional[typing.Union[int, str]] = 0
|
|
||||||
|
|
||||||
use_prompt_name: typing.Optional[str] = 'default'
|
|
||||||
|
|
||||||
using_conversation: typing.Optional[Conversation] = None
|
|
||||||
|
|
||||||
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
|
|
||||||
|
|
||||||
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
|
||||||
|
|
||||||
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
|
|
||||||
|
|
||||||
semaphore: typing.Optional[asyncio.Semaphore] = None
|
|
||||||
"""当前会话的信号量,用于限制并发"""
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|||||||
@@ -1,10 +1,11 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
from .. import stage, app
|
from .. import stage, app
|
||||||
from ...utils import version, proxy, announce
|
from ...utils import version, proxy, announce
|
||||||
from ...pipeline import pool, controller, pipelinemgr
|
from ...pipeline import pool, controller, pipelinemgr
|
||||||
from ...plugin import manager as plugin_mgr
|
from ...plugin import connector as plugin_connector
|
||||||
from ...command import cmdmgr
|
from ...command import cmdmgr
|
||||||
from ...provider.session import sessionmgr as llm_session_mgr
|
from ...provider.session import sessionmgr as llm_session_mgr
|
||||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||||
@@ -59,10 +60,13 @@ class BuildAppStage(stage.BootingStage):
|
|||||||
ap.persistence_mgr = persistence_mgr_inst
|
ap.persistence_mgr = persistence_mgr_inst
|
||||||
await persistence_mgr_inst.initialize()
|
await persistence_mgr_inst.initialize()
|
||||||
|
|
||||||
plugin_mgr_inst = plugin_mgr.PluginManager(ap)
|
async def runtime_disconnect_callback(connector: plugin_connector.PluginRuntimeConnector) -> None:
|
||||||
await plugin_mgr_inst.initialize()
|
await asyncio.sleep(3)
|
||||||
ap.plugin_mgr = plugin_mgr_inst
|
await plugin_connector_inst.initialize()
|
||||||
await plugin_mgr_inst.load_plugins()
|
|
||||||
|
plugin_connector_inst = plugin_connector.PluginRuntimeConnector(ap, runtime_disconnect_callback)
|
||||||
|
await plugin_connector_inst.initialize()
|
||||||
|
ap.plugin_connector = plugin_connector_inst
|
||||||
|
|
||||||
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
cmd_mgr_inst = cmdmgr.CommandManager(ap)
|
||||||
await cmd_mgr_inst.initialize()
|
await cmd_mgr_inst.initialize()
|
||||||
|
|||||||
22
pkg/entity/persistence/bstorage.py
Normal file
22
pkg/entity/persistence/bstorage.py
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import sqlalchemy
|
||||||
|
|
||||||
|
from .base import Base
|
||||||
|
|
||||||
|
|
||||||
|
class BinaryStorage(Base):
|
||||||
|
"""Current for plugin use only"""
|
||||||
|
|
||||||
|
__tablename__ = 'binary_storages'
|
||||||
|
|
||||||
|
unique_key = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True)
|
||||||
|
key = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
owner_type = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
owner = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
|
||||||
|
value = sqlalchemy.Column(sqlalchemy.LargeBinary, nullable=False)
|
||||||
|
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||||
|
updated_at = sqlalchemy.Column(
|
||||||
|
sqlalchemy.DateTime,
|
||||||
|
nullable=False,
|
||||||
|
server_default=sqlalchemy.func.now(),
|
||||||
|
onupdate=sqlalchemy.func.now(),
|
||||||
|
)
|
||||||
@@ -13,6 +13,8 @@ class PluginSetting(Base):
|
|||||||
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
enabled = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=True)
|
||||||
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
priority = sqlalchemy.Column(sqlalchemy.Integer, nullable=False, default=0)
|
||||||
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||||
|
install_source = sqlalchemy.Column(sqlalchemy.String(255), nullable=False, default='github')
|
||||||
|
install_info = sqlalchemy.Column(sqlalchemy.JSON, nullable=False, default=dict)
|
||||||
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
created_at = sqlalchemy.Column(sqlalchemy.DateTime, nullable=False, server_default=sqlalchemy.func.now())
|
||||||
updated_at = sqlalchemy.Column(
|
updated_at = sqlalchemy.Column(
|
||||||
sqlalchemy.DateTime,
|
sqlalchemy.DateTime,
|
||||||
|
|||||||
@@ -44,6 +44,38 @@ class PersistenceManager:
|
|||||||
|
|
||||||
await self.create_tables()
|
await self.create_tables()
|
||||||
|
|
||||||
|
# run migrations
|
||||||
|
database_version = await self.execute_async(
|
||||||
|
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
||||||
|
)
|
||||||
|
|
||||||
|
database_version = int(database_version.fetchone()[1])
|
||||||
|
required_database_version = constants.required_database_version
|
||||||
|
|
||||||
|
if database_version < required_database_version:
|
||||||
|
migrations = migration.preregistered_db_migrations
|
||||||
|
migrations.sort(key=lambda x: x.number)
|
||||||
|
|
||||||
|
last_migration_number = database_version
|
||||||
|
|
||||||
|
for migration_cls in migrations:
|
||||||
|
migration_instance = migration_cls(self.ap)
|
||||||
|
|
||||||
|
if (
|
||||||
|
migration_instance.number > database_version
|
||||||
|
and migration_instance.number <= required_database_version
|
||||||
|
):
|
||||||
|
await migration_instance.upgrade()
|
||||||
|
await self.execute_async(
|
||||||
|
sqlalchemy.update(metadata.Metadata)
|
||||||
|
.where(metadata.Metadata.key == 'database_version')
|
||||||
|
.values({'value': str(migration_instance.number)})
|
||||||
|
)
|
||||||
|
last_migration_number = migration_instance.number
|
||||||
|
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
||||||
|
|
||||||
|
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
||||||
|
|
||||||
async def create_tables(self):
|
async def create_tables(self):
|
||||||
# create tables
|
# create tables
|
||||||
async with self.get_db_engine().connect() as conn:
|
async with self.get_db_engine().connect() as conn:
|
||||||
@@ -87,38 +119,6 @@ class PersistenceManager:
|
|||||||
|
|
||||||
# =================================
|
# =================================
|
||||||
|
|
||||||
# run migrations
|
|
||||||
database_version = await self.execute_async(
|
|
||||||
sqlalchemy.select(metadata.Metadata).where(metadata.Metadata.key == 'database_version')
|
|
||||||
)
|
|
||||||
|
|
||||||
database_version = int(database_version.fetchone()[1])
|
|
||||||
required_database_version = constants.required_database_version
|
|
||||||
|
|
||||||
if database_version < required_database_version:
|
|
||||||
migrations = migration.preregistered_db_migrations
|
|
||||||
migrations.sort(key=lambda x: x.number)
|
|
||||||
|
|
||||||
last_migration_number = database_version
|
|
||||||
|
|
||||||
for migration_cls in migrations:
|
|
||||||
migration_instance = migration_cls(self.ap)
|
|
||||||
|
|
||||||
if (
|
|
||||||
migration_instance.number > database_version
|
|
||||||
and migration_instance.number <= required_database_version
|
|
||||||
):
|
|
||||||
await migration_instance.upgrade()
|
|
||||||
await self.execute_async(
|
|
||||||
sqlalchemy.update(metadata.Metadata)
|
|
||||||
.where(metadata.Metadata.key == 'database_version')
|
|
||||||
.values({'value': str(migration_instance.number)})
|
|
||||||
)
|
|
||||||
last_migration_number = migration_instance.number
|
|
||||||
self.ap.logger.info(f'Migration {migration_instance.number} completed.')
|
|
||||||
|
|
||||||
self.ap.logger.info(f'Successfully upgraded database to version {last_migration_number}.')
|
|
||||||
|
|
||||||
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
async def execute_async(self, *args, **kwargs) -> sqlalchemy.engine.cursor.CursorResult:
|
||||||
async with self.get_db_engine().connect() as conn:
|
async with self.get_db_engine().connect() as conn:
|
||||||
result = await conn.execute(*args, **kwargs)
|
result = await conn.execute(*args, **kwargs)
|
||||||
@@ -128,10 +128,13 @@ class PersistenceManager:
|
|||||||
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
def get_db_engine(self) -> sqlalchemy_asyncio.AsyncEngine:
|
||||||
return self.db.get_engine()
|
return self.db.get_engine()
|
||||||
|
|
||||||
def serialize_model(self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base) -> dict:
|
def serialize_model(
|
||||||
|
self, model: typing.Type[sqlalchemy.Base], data: sqlalchemy.Base, masked_columns: list[str] = []
|
||||||
|
) -> dict:
|
||||||
return {
|
return {
|
||||||
column.name: getattr(data, column.name)
|
column.name: getattr(data, column.name)
|
||||||
if not isinstance(getattr(data, column.name), (datetime.datetime))
|
if not isinstance(getattr(data, column.name), (datetime.datetime))
|
||||||
else getattr(data, column.name).isoformat()
|
else getattr(data, column.name).isoformat()
|
||||||
for column in model.__table__.columns
|
for column in model.__table__.columns
|
||||||
|
if column.name not in masked_columns
|
||||||
}
|
}
|
||||||
|
|||||||
20
pkg/persistence/migrations/dbm004_plugin_config.py
Normal file
20
pkg/persistence/migrations/dbm004_plugin_config.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(4)
|
||||||
|
class DBMigratePluginConfig(migration.DBMigration):
|
||||||
|
"""插件配置"""
|
||||||
|
|
||||||
|
async def upgrade(self):
|
||||||
|
"""升级"""
|
||||||
|
|
||||||
|
if 'plugin' not in self.ap.instance_config.data:
|
||||||
|
self.ap.instance_config.data['plugin'] = {
|
||||||
|
'runtime_ws_url': 'ws://localhost:5400/control/ws',
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.ap.instance_config.dump_config()
|
||||||
|
|
||||||
|
async def downgrade(self):
|
||||||
|
"""降级"""
|
||||||
|
pass
|
||||||
32
pkg/persistence/migrations/dbm005_plugin_install_source.py
Normal file
32
pkg/persistence/migrations/dbm005_plugin_install_source.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
import sqlalchemy
|
||||||
|
from .. import migration
|
||||||
|
|
||||||
|
|
||||||
|
@migration.migration_class(5)
|
||||||
|
class DBMigratePluginInstallSource(migration.DBMigration):
|
||||||
|
"""插件安装来源"""
|
||||||
|
|
||||||
|
async def upgrade(self):
|
||||||
|
"""升级"""
|
||||||
|
# 查询表结构获取所有列名(异步执行 SQL)
|
||||||
|
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.text('PRAGMA table_info(plugin_settings);'))
|
||||||
|
# fetchall() 是同步方法,无需 await
|
||||||
|
columns = [row[1] for row in result.fetchall()]
|
||||||
|
|
||||||
|
# 检查并添加 install_source 列
|
||||||
|
if 'install_source' not in columns:
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.text(
|
||||||
|
"ALTER TABLE plugin_settings ADD COLUMN install_source VARCHAR(255) NOT NULL DEFAULT 'github'"
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# 检查并添加 install_info 列
|
||||||
|
if 'install_info' not in columns:
|
||||||
|
await self.ap.persistence_mgr.execute_async(
|
||||||
|
sqlalchemy.text("ALTER TABLE plugin_settings ADD COLUMN install_info JSON NOT NULL DEFAULT '{}'")
|
||||||
|
)
|
||||||
|
|
||||||
|
async def downgrade(self):
|
||||||
|
"""降级"""
|
||||||
|
pass
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('BanSessionCheckStage')
|
@stage.stage_class('BanSessionCheckStage')
|
||||||
@@ -14,7 +14,7 @@ class BanSessionCheckStage(stage.PipelineStage):
|
|||||||
async def initialize(self, pipeline_config: dict):
|
async def initialize(self, pipeline_config: dict):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
found = False
|
found = False
|
||||||
|
|
||||||
mode = query.pipeline_config['trigger']['access-control']['mode']
|
mode = query.pipeline_config['trigger']['access-control']['mode']
|
||||||
|
|||||||
@@ -3,12 +3,11 @@ from __future__ import annotations
|
|||||||
from ...core import app
|
from ...core import app
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from . import filter as filter_model, entities as filter_entities
|
from . import filter as filter_model, entities as filter_entities
|
||||||
from ...provider import entities as llm_entities
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from . import filters
|
from . import filters
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(filters)
|
importutil.import_modules_in_pkg(filters)
|
||||||
@@ -58,7 +57,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
async def _pre_process(
|
async def _pre_process(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""请求llm前处理消息
|
"""请求llm前处理消息
|
||||||
只要有一个不通过就不放行,只放行 PASS 的消息
|
只要有一个不通过就不放行,只放行 PASS 的消息
|
||||||
@@ -67,7 +66,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
if not message.strip():
|
if not message.strip():
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
else:
|
else:
|
||||||
for filter in self.filter_chain:
|
for filter in self.filter_chain:
|
||||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||||
@@ -86,14 +85,14 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
elif result.level == filter_entities.ResultLevel.PASS: # 传到下一个
|
||||||
message = result.replacement
|
message = result.replacement
|
||||||
|
|
||||||
query.message_chain = platform_message.MessageChain(platform_message.Plain(message))
|
query.message_chain = platform_message.MessageChain([platform_message.Plain(text=message)])
|
||||||
|
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
|
|
||||||
async def _post_process(
|
async def _post_process(
|
||||||
self,
|
self,
|
||||||
message: str,
|
message: str,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""请求llm后处理响应
|
"""请求llm后处理响应
|
||||||
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
只要是 PASS 或者 MASKED 的就通过此 filter,将其 replacement 设置为message,进入下一个 filter
|
||||||
@@ -123,7 +122,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
|
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
if stage_inst_name == 'PreContentFilterStage':
|
if stage_inst_name == 'PreContentFilterStage':
|
||||||
contain_non_text = False
|
contain_non_text = False
|
||||||
@@ -142,7 +141,7 @@ class ContentFilterStage(stage.PipelineStage):
|
|||||||
return await self._pre_process(str(query.message_chain).strip(), query)
|
return await self._pre_process(str(query.message_chain).strip(), query)
|
||||||
elif stage_inst_name == 'PostContentFilterStage':
|
elif stage_inst_name == 'PostContentFilterStage':
|
||||||
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
# 仅处理 query.resp_messages[-1].content 是 str 的情况
|
||||||
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
|
if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
|
||||||
query.resp_messages[-1].content, str
|
query.resp_messages[-1].content, str
|
||||||
):
|
):
|
||||||
return await self._post_process(query.resp_messages[-1].content, query)
|
return await self._post_process(query.resp_messages[-1].content, query)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import enum
|
import enum
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
|
|
||||||
class ResultLevel(enum.Enum):
|
class ResultLevel(enum.Enum):
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
preregistered_filters: list[typing.Type[ContentFilter]] = []
|
||||||
|
|
||||||
@@ -60,7 +60,7 @@ class ContentFilter(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(self, query: core_entities.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str = None, image_url=None) -> entities.FilterResult:
|
||||||
"""处理消息
|
"""处理消息
|
||||||
|
|
||||||
分为前后阶段,具体取决于 enable_stages 的值。
|
分为前后阶段,具体取决于 enable_stages 的值。
|
||||||
|
|||||||
@@ -4,8 +4,7 @@ import aiohttp
|
|||||||
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
BAIDU_EXAMINE_URL = 'https://aip.baidubce.com/rest/2.0/solution/v1/text_censor/v2/user_defined?access_token={}'
|
||||||
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
BAIDU_EXAMINE_TOKEN_URL = 'https://aip.baidubce.com/oauth/2.0/token'
|
||||||
@@ -27,7 +26,7 @@ class BaiduCloudExamine(filter_model.ContentFilter):
|
|||||||
) as resp:
|
) as resp:
|
||||||
return (await resp.json())['access_token']
|
return (await resp.json())['access_token']
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.post(
|
async with session.post(
|
||||||
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
BAIDU_EXAMINE_URL.format(await self._get_token()),
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
|
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@filter_model.filter_class('ban-word-filter')
|
@filter_model.filter_class('ban-word-filter')
|
||||||
@@ -13,7 +13,7 @@ class BanWordFilter(filter_model.ContentFilter):
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||||
found = False
|
found = False
|
||||||
|
|
||||||
for word in self.ap.sensitive_meta.data['words']:
|
for word in self.ap.sensitive_meta.data['words']:
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import re
|
|||||||
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import filter as filter_model
|
from .. import filter as filter_model
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@filter_model.filter_class('content-ignore')
|
@filter_model.filter_class('content-ignore')
|
||||||
@@ -16,7 +16,7 @@ class ContentIgnore(filter_model.ContentFilter):
|
|||||||
entities.EnableStage.PRE,
|
entities.EnableStage.PRE,
|
||||||
]
|
]
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, message: str) -> entities.FilterResult:
|
async def process(self, query: pipeline_query.Query, message: str) -> entities.FilterResult:
|
||||||
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
if 'prefix' in query.pipeline_config['trigger']['ignore-rules']:
|
||||||
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
for rule in query.pipeline_config['trigger']['ignore-rules']['prefix']:
|
||||||
if message.startswith(rule):
|
if message.startswith(rule):
|
||||||
|
|||||||
@@ -3,7 +3,10 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
from ..core import app, entities
|
from ..core import app
|
||||||
|
from ..core import entities as core_entities
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class Controller:
|
class Controller:
|
||||||
@@ -22,19 +25,19 @@ class Controller:
|
|||||||
"""事件处理循环"""
|
"""事件处理循环"""
|
||||||
try:
|
try:
|
||||||
while True:
|
while True:
|
||||||
selected_query: entities.Query = None
|
selected_query: pipeline_query.Query = None
|
||||||
|
|
||||||
# 取请求
|
# 取请求
|
||||||
async with self.ap.query_pool:
|
async with self.ap.query_pool:
|
||||||
queries: list[entities.Query] = self.ap.query_pool.queries
|
queries: list[pipeline_query.Query] = self.ap.query_pool.queries
|
||||||
|
|
||||||
for query in queries:
|
for query in queries:
|
||||||
session = await self.ap.sess_mgr.get_session(query)
|
session = await self.ap.sess_mgr.get_session(query)
|
||||||
self.ap.logger.debug(f'Checking query {query} session {session}')
|
self.ap.logger.debug(f'Checking query {query} session {session}')
|
||||||
|
|
||||||
if not session.semaphore.locked():
|
if not session._semaphore.locked():
|
||||||
selected_query = query
|
selected_query = query
|
||||||
await session.semaphore.acquire()
|
await session._semaphore.acquire()
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
@@ -46,7 +49,7 @@ class Controller:
|
|||||||
|
|
||||||
if selected_query:
|
if selected_query:
|
||||||
|
|
||||||
async def _process_query(selected_query: entities.Query):
|
async def _process_query(selected_query: pipeline_query.Query):
|
||||||
async with self.semaphore: # 总并发上限
|
async with self.semaphore: # 总并发上限
|
||||||
# find pipeline
|
# find pipeline
|
||||||
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
# Here firstly find the bot, then find the pipeline, in case the bot adapter's config is not the latest one.
|
||||||
@@ -59,7 +62,7 @@ class Controller:
|
|||||||
await pipeline.run(selected_query)
|
await pipeline.run(selected_query)
|
||||||
|
|
||||||
async with self.ap.query_pool:
|
async with self.ap.query_pool:
|
||||||
(await self.ap.sess_mgr.get_session(selected_query)).semaphore.release()
|
(await self.ap.sess_mgr.get_session(selected_query))._semaphore.release()
|
||||||
# 通知其他协程,有新的请求可以处理了
|
# 通知其他协程,有新的请求可以处理了
|
||||||
self.ap.query_pool.condition.notify_all()
|
self.ap.query_pool.condition.notify_all()
|
||||||
|
|
||||||
@@ -68,8 +71,8 @@ class Controller:
|
|||||||
kind='query',
|
kind='query',
|
||||||
name=f'query-{selected_query.query_id}',
|
name=f'query-{selected_query.query_id}',
|
||||||
scopes=[
|
scopes=[
|
||||||
entities.LifecycleControlScope.APPLICATION,
|
core_entities.LifecycleControlScope.APPLICATION,
|
||||||
entities.LifecycleControlScope.PLATFORM,
|
core_entities.LifecycleControlScope.PLATFORM,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|||||||
import enum
|
import enum
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
from ..platform.types import message as platform_message
|
|
||||||
|
|
||||||
from ..core import entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
|
||||||
|
|
||||||
class ResultType(enum.Enum):
|
class ResultType(enum.Enum):
|
||||||
@@ -20,7 +20,7 @@ class ResultType(enum.Enum):
|
|||||||
class StageProcessResult(pydantic.BaseModel):
|
class StageProcessResult(pydantic.BaseModel):
|
||||||
result_type: ResultType
|
result_type: ResultType
|
||||||
|
|
||||||
new_query: entities.Query
|
new_query: pipeline_query.Query
|
||||||
|
|
||||||
user_notice: typing.Optional[
|
user_notice: typing.Optional[
|
||||||
typing.Union[
|
typing.Union[
|
||||||
|
|||||||
@@ -5,10 +5,9 @@ import traceback
|
|||||||
|
|
||||||
from . import strategy
|
from . import strategy
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ...platform.types import message as platform_message
|
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from . import strategies
|
from . import strategies
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(strategies)
|
importutil.import_modules_in_pkg(strategies)
|
||||||
@@ -67,7 +66,7 @@ class LongTextProcessStage(stage.PipelineStage):
|
|||||||
|
|
||||||
await self.strategy_impl.initialize()
|
await self.strategy_impl.initialize()
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
# 检查是否包含非 Plain 组件
|
# 检查是否包含非 Plain 组件
|
||||||
contains_non_plain = False
|
contains_non_plain = False
|
||||||
|
|
||||||
|
|||||||
@@ -3,9 +3,9 @@ from __future__ import annotations
|
|||||||
|
|
||||||
|
|
||||||
from .. import strategy as strategy_model
|
from .. import strategy as strategy_model
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....platform.types import message as platform_message
|
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
|
||||||
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
ForwardMessageDiaplay = platform_message.ForwardMessageDiaplay
|
||||||
Forward = platform_message.Forward
|
Forward = platform_message.Forward
|
||||||
@@ -13,7 +13,7 @@ Forward = platform_message.Forward
|
|||||||
|
|
||||||
@strategy_model.strategy_class('forward')
|
@strategy_model.strategy_class('forward')
|
||||||
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
class ForwardComponentStrategy(strategy_model.LongTextStrategy):
|
||||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||||
display = ForwardMessageDiaplay(
|
display = ForwardMessageDiaplay(
|
||||||
title='群聊的聊天记录',
|
title='群聊的聊天记录',
|
||||||
brief='[聊天记录]',
|
brief='[聊天记录]',
|
||||||
|
|||||||
@@ -8,10 +8,10 @@ import re
|
|||||||
from PIL import Image, ImageDraw, ImageFont
|
from PIL import Image, ImageDraw, ImageFont
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from ....platform.types import message as platform_message
|
|
||||||
|
|
||||||
from .. import strategy as strategy_model
|
from .. import strategy as strategy_model
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
|
||||||
|
|
||||||
@strategy_model.strategy_class('image')
|
@strategy_model.strategy_class('image')
|
||||||
@@ -27,7 +27,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
|||||||
encoding='utf-8',
|
encoding='utf-8',
|
||||||
)
|
)
|
||||||
|
|
||||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||||
img_path = self.text_to_image(
|
img_path = self.text_to_image(
|
||||||
text_str=message,
|
text_str=message,
|
||||||
save_as='temp/{}.png'.format(int(time.time())),
|
save_as='temp/{}.png'.format(int(time.time())),
|
||||||
@@ -131,7 +131,7 @@ class Text2ImageStrategy(strategy_model.LongTextStrategy):
|
|||||||
text_str: str,
|
text_str: str,
|
||||||
save_as='temp.png',
|
save_as='temp.png',
|
||||||
width=800,
|
width=800,
|
||||||
query: core_entities.Query = None,
|
query: pipeline_query.Query = None,
|
||||||
):
|
):
|
||||||
text_str = text_str.replace('\t', ' ')
|
text_str = text_str.replace('\t', ' ')
|
||||||
|
|
||||||
|
|||||||
@@ -4,8 +4,9 @@ import typing
|
|||||||
|
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
preregistered_strategies: list[typing.Type[LongTextStrategy]] = []
|
||||||
@@ -49,7 +50,7 @@ class LongTextStrategy(metaclass=abc.ABCMeta):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(self, message: str, query: core_entities.Query) -> list[platform_message.MessageComponent]:
|
async def process(self, message: str, query: pipeline_query.Query) -> list[platform_message.MessageComponent]:
|
||||||
"""处理长文本
|
"""处理长文本
|
||||||
|
|
||||||
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
在 platform.json 中配置 long-text-process 字段,只要 文本长度超过了 threshold 就会调用此方法
|
||||||
|
|||||||
@@ -1,10 +1,9 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from . import truncator
|
from . import truncator
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
from . import truncators
|
from . import truncators
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(truncators)
|
importutil.import_modules_in_pkg(truncators)
|
||||||
@@ -29,7 +28,7 @@ class ConversationMessageTruncator(stage.PipelineStage):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(f'未知的截断器: {use_method}')
|
raise ValueError(f'未知的截断器: {use_method}')
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
query = await self.trun.truncate(query)
|
query = await self.trun.truncate(query)
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|||||||
import typing
|
import typing
|
||||||
import abc
|
import abc
|
||||||
|
|
||||||
from ...core import entities as core_entities, app
|
from ...core import app
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
preregistered_truncators: list[typing.Type[Truncator]] = []
|
preregistered_truncators: list[typing.Type[Truncator]] = []
|
||||||
|
|
||||||
@@ -47,7 +47,7 @@ class Truncator(abc.ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||||
"""截断
|
"""截断
|
||||||
|
|
||||||
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
一般只需要操作query.messages,也可以扩展操作query.prompt, query.user_message。
|
||||||
|
|||||||
@@ -1,14 +1,14 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from .. import truncator
|
from .. import truncator
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@truncator.truncator_class('round')
|
@truncator.truncator_class('round')
|
||||||
class RoundTruncator(truncator.Truncator):
|
class RoundTruncator(truncator.Truncator):
|
||||||
"""前文回合数阶段器"""
|
"""前文回合数阶段器"""
|
||||||
|
|
||||||
async def truncate(self, query: core_entities.Query) -> core_entities.Query:
|
async def truncate(self, query: pipeline_query.Query) -> pipeline_query.Query:
|
||||||
"""截断"""
|
"""截断"""
|
||||||
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
max_round = query.pipeline_config['ai']['local-agent']['max-round']
|
||||||
|
|
||||||
|
|||||||
@@ -5,14 +5,18 @@ import traceback
|
|||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
from ..core import app, entities
|
from ..core import app
|
||||||
from . import entities as pipeline_entities
|
from . import entities as pipeline_entities
|
||||||
from ..entity.persistence import pipeline as persistence_pipeline
|
from ..entity.persistence import pipeline as persistence_pipeline
|
||||||
from . import stage
|
from . import stage
|
||||||
from ..platform.types import message as platform_message, events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..plugin import events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
|
import langbot_plugin.api.entities.events as events
|
||||||
from ..utils import importutil
|
from ..utils import importutil
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
from . import (
|
from . import (
|
||||||
resprule,
|
resprule,
|
||||||
bansess,
|
bansess,
|
||||||
@@ -75,17 +79,17 @@ class RuntimePipeline:
|
|||||||
self.pipeline_entity = pipeline_entity
|
self.pipeline_entity = pipeline_entity
|
||||||
self.stage_containers = stage_containers
|
self.stage_containers = stage_containers
|
||||||
|
|
||||||
async def run(self, query: entities.Query):
|
async def run(self, query: pipeline_query.Query):
|
||||||
query.pipeline_config = self.pipeline_entity.config
|
query.pipeline_config = self.pipeline_entity.config
|
||||||
await self.process_query(query)
|
await self.process_query(query)
|
||||||
|
|
||||||
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
|
async def _check_output(self, query: pipeline_query.Query, result: pipeline_entities.StageProcessResult):
|
||||||
"""检查输出"""
|
"""检查输出"""
|
||||||
if result.user_notice:
|
if result.user_notice:
|
||||||
# 处理str类型
|
# 处理str类型
|
||||||
|
|
||||||
if isinstance(result.user_notice, str):
|
if isinstance(result.user_notice, str):
|
||||||
result.user_notice = platform_message.MessageChain(platform_message.Plain(result.user_notice))
|
result.user_notice = platform_message.MessageChain([platform_message.Plain(text=result.user_notice)])
|
||||||
elif isinstance(result.user_notice, list):
|
elif isinstance(result.user_notice, list):
|
||||||
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
result.user_notice = platform_message.MessageChain(*result.user_notice)
|
||||||
|
|
||||||
@@ -109,7 +113,7 @@ class RuntimePipeline:
|
|||||||
async def _execute_from_stage(
|
async def _execute_from_stage(
|
||||||
self,
|
self,
|
||||||
stage_index: int,
|
stage_index: int,
|
||||||
query: entities.Query,
|
query: pipeline_query.Query,
|
||||||
):
|
):
|
||||||
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
"""从指定阶段开始执行,实现了责任链模式和基于生成器的阶段分叉功能。
|
||||||
|
|
||||||
@@ -136,7 +140,7 @@ class RuntimePipeline:
|
|||||||
while i < len(self.stage_containers):
|
while i < len(self.stage_containers):
|
||||||
stage_container = self.stage_containers[i]
|
stage_container = self.stage_containers[i]
|
||||||
|
|
||||||
query.current_stage = stage_container # 标记到 Query 对象里
|
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
|
||||||
|
|
||||||
result = stage_container.inst.process(query, stage_container.inst_name)
|
result = stage_container.inst.process(query, stage_container.inst_name)
|
||||||
|
|
||||||
@@ -169,26 +173,26 @@ class RuntimePipeline:
|
|||||||
|
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
async def process_query(self, query: entities.Query):
|
async def process_query(self, query: pipeline_query.Query):
|
||||||
"""处理请求"""
|
"""处理请求"""
|
||||||
try:
|
try:
|
||||||
# ======== 触发 MessageReceived 事件 ========
|
# ======== 触发 MessageReceived 事件 ========
|
||||||
event_type = (
|
event_type = (
|
||||||
events.PersonMessageReceived
|
events.PersonMessageReceived
|
||||||
if query.launcher_type == entities.LauncherTypes.PERSON
|
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||||
else events.GroupMessageReceived
|
else events.GroupMessageReceived
|
||||||
)
|
)
|
||||||
|
|
||||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
event_obj = event_type(
|
||||||
event=event_type(
|
query=query,
|
||||||
launcher_type=query.launcher_type.value,
|
launcher_type=query.launcher_type.value,
|
||||||
launcher_id=query.launcher_id,
|
launcher_id=query.launcher_id,
|
||||||
sender_id=query.sender_id,
|
sender_id=query.sender_id,
|
||||||
message_chain=query.message_chain,
|
message_chain=query.message_chain,
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_ctx = await self.ap.plugin_connector.emit_event(event_obj)
|
||||||
|
|
||||||
if event_ctx.is_prevented_default():
|
if event_ctx.is_prevented_default():
|
||||||
return
|
return
|
||||||
|
|
||||||
@@ -196,11 +200,12 @@ class RuntimePipeline:
|
|||||||
|
|
||||||
await self._execute_from_stage(0, query)
|
await self._execute_from_stage(0, query)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
|
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
|
||||||
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
|
||||||
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
|
||||||
finally:
|
finally:
|
||||||
self.ap.logger.debug(f'Query {query} processed')
|
self.ap.logger.debug(f'Query {query} processed')
|
||||||
|
del self.ap.query_pool.cached_queries[query.query_id]
|
||||||
|
|
||||||
|
|
||||||
class PipelineManager:
|
class PipelineManager:
|
||||||
|
|||||||
@@ -3,10 +3,11 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..platform import adapter as msadapter
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ..platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
from ..platform.types import events as platform_events
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
|
|
||||||
|
|
||||||
class QueryPool:
|
class QueryPool:
|
||||||
@@ -16,7 +17,10 @@ class QueryPool:
|
|||||||
|
|
||||||
pool_lock: asyncio.Lock
|
pool_lock: asyncio.Lock
|
||||||
|
|
||||||
queries: list[entities.Query]
|
queries: list[pipeline_query.Query]
|
||||||
|
|
||||||
|
cached_queries: dict[int, pipeline_query.Query]
|
||||||
|
"""Cached queries, used for plugin backward api call, will be removed after the query completely processed"""
|
||||||
|
|
||||||
condition: asyncio.Condition
|
condition: asyncio.Condition
|
||||||
|
|
||||||
@@ -24,34 +28,38 @@ class QueryPool:
|
|||||||
self.query_id_counter = 0
|
self.query_id_counter = 0
|
||||||
self.pool_lock = asyncio.Lock()
|
self.pool_lock = asyncio.Lock()
|
||||||
self.queries = []
|
self.queries = []
|
||||||
|
self.cached_queries = {}
|
||||||
self.condition = asyncio.Condition(self.pool_lock)
|
self.condition = asyncio.Condition(self.pool_lock)
|
||||||
|
|
||||||
async def add_query(
|
async def add_query(
|
||||||
self,
|
self,
|
||||||
bot_uuid: str,
|
bot_uuid: str,
|
||||||
launcher_type: entities.LauncherTypes,
|
launcher_type: provider_session.LauncherTypes,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
sender_id: typing.Union[int, str],
|
sender_id: typing.Union[int, str],
|
||||||
message_event: platform_events.MessageEvent,
|
message_event: platform_events.MessageEvent,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
adapter: msadapter.MessagePlatformAdapter,
|
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||||
pipeline_uuid: typing.Optional[str] = None,
|
pipeline_uuid: typing.Optional[str] = None,
|
||||||
) -> entities.Query:
|
) -> pipeline_query.Query:
|
||||||
async with self.condition:
|
async with self.condition:
|
||||||
query = entities.Query(
|
query_id = self.query_id_counter
|
||||||
|
query = pipeline_query.Query(
|
||||||
bot_uuid=bot_uuid,
|
bot_uuid=bot_uuid,
|
||||||
query_id=self.query_id_counter,
|
query_id=query_id,
|
||||||
launcher_type=launcher_type,
|
launcher_type=launcher_type,
|
||||||
launcher_id=launcher_id,
|
launcher_id=launcher_id,
|
||||||
sender_id=sender_id,
|
sender_id=sender_id,
|
||||||
message_event=message_event,
|
message_event=message_event,
|
||||||
message_chain=message_chain,
|
message_chain=message_chain,
|
||||||
|
variables={},
|
||||||
resp_messages=[],
|
resp_messages=[],
|
||||||
resp_message_chain=[],
|
resp_message_chain=[],
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
pipeline_uuid=pipeline_uuid,
|
pipeline_uuid=pipeline_uuid,
|
||||||
)
|
)
|
||||||
self.queries.append(query)
|
self.queries.append(query)
|
||||||
|
self.cached_queries[query_id] = query
|
||||||
self.query_id_counter += 1
|
self.query_id_counter += 1
|
||||||
self.condition.notify_all()
|
self.condition.notify_all()
|
||||||
|
|
||||||
|
|||||||
@@ -3,10 +3,10 @@ from __future__ import annotations
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
from langbot_plugin.api.entities.builtin.provider import message as provider_message
|
||||||
from ...provider import entities as llm_entities
|
import langbot_plugin.api.entities.events as events
|
||||||
from ...plugin import events
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('PreProcessor')
|
@stage.stage_class('PreProcessor')
|
||||||
@@ -26,7 +26,7 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
@@ -49,31 +49,31 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
query.bot_uuid,
|
query.bot_uuid,
|
||||||
)
|
)
|
||||||
|
|
||||||
conversation.use_llm_model = llm_model
|
|
||||||
|
|
||||||
# 设置query
|
# 设置query
|
||||||
query.session = session
|
query.session = session
|
||||||
query.prompt = conversation.prompt.copy()
|
query.prompt = conversation.prompt.copy()
|
||||||
query.messages = conversation.messages.copy()
|
query.messages = conversation.messages.copy()
|
||||||
|
|
||||||
query.use_llm_model = llm_model
|
query.use_llm_model_uuid = llm_model.model_entity.uuid
|
||||||
|
|
||||||
if selected_runner == 'local-agent':
|
if selected_runner == 'local-agent':
|
||||||
query.use_funcs = (
|
query.use_funcs = []
|
||||||
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None
|
|
||||||
)
|
|
||||||
|
|
||||||
query.variables = {
|
if llm_model.model_entity.abilities.__contains__('func_call'):
|
||||||
|
query.use_funcs = await self.ap.tool_mgr.get_all_tools()
|
||||||
|
|
||||||
|
variables = {
|
||||||
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||||
'conversation_id': conversation.uuid,
|
'conversation_id': conversation.uuid,
|
||||||
'msg_create_time': (
|
'msg_create_time': (
|
||||||
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
|
int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
query.variables.update(variables)
|
||||||
|
|
||||||
# Check if this model supports vision, if not, remove all images
|
# Check if this model supports vision, if not, remove all images
|
||||||
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
|
||||||
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
|
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
|
||||||
for msg in query.messages:
|
for msg in query.messages:
|
||||||
if isinstance(msg.content, list):
|
if isinstance(msg.content, list):
|
||||||
for me in msg.content:
|
for me in msg.content:
|
||||||
@@ -87,39 +87,35 @@ class PreProcessor(stage.PipelineStage):
|
|||||||
|
|
||||||
for me in query.message_chain:
|
for me in query.message_chain:
|
||||||
if isinstance(me, platform_message.Plain):
|
if isinstance(me, platform_message.Plain):
|
||||||
content_list.append(llm_entities.ContentElement.from_text(me.text))
|
content_list.append(provider_message.ContentElement.from_text(me.text))
|
||||||
plain_text += me.text
|
plain_text += me.text
|
||||||
elif isinstance(me, platform_message.Image):
|
elif isinstance(me, platform_message.Image):
|
||||||
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
|
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||||
'vision'
|
|
||||||
):
|
|
||||||
if me.base64 is not None:
|
if me.base64 is not None:
|
||||||
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
|
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
|
||||||
elif isinstance(me, platform_message.Quote) and qoute_msg:
|
elif isinstance(me, platform_message.Quote) and qoute_msg:
|
||||||
for msg in me.origin:
|
for msg in me.origin:
|
||||||
if isinstance(msg, platform_message.Plain):
|
if isinstance(msg, platform_message.Plain):
|
||||||
content_list.append(llm_entities.ContentElement.from_text(msg.text))
|
content_list.append(provider_message.ContentElement.from_text(msg.text))
|
||||||
elif isinstance(msg, platform_message.Image):
|
elif isinstance(msg, platform_message.Image):
|
||||||
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
|
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
|
||||||
'vision'
|
|
||||||
):
|
|
||||||
if msg.base64 is not None:
|
if msg.base64 is not None:
|
||||||
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
|
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
|
||||||
|
|
||||||
query.variables['user_message_text'] = plain_text
|
query.variables['user_message_text'] = plain_text
|
||||||
|
|
||||||
query.user_message = llm_entities.Message(role='user', content=content_list)
|
query.user_message = provider_message.Message(role='user', content=content_list)
|
||||||
# =========== 触发事件 PromptPreProcessing
|
# =========== 触发事件 PromptPreProcessing
|
||||||
|
|
||||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
event = events.PromptPreProcessing(
|
||||||
event=events.PromptPreProcessing(
|
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||||
session_name=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
default_prompt=query.prompt.messages,
|
||||||
default_prompt=query.prompt.messages,
|
prompt=query.messages,
|
||||||
prompt=query.messages,
|
query=query,
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||||
|
|
||||||
query.prompt.messages = event_ctx.event.default_prompt
|
query.prompt.messages = event_ctx.event.default_prompt
|
||||||
query.messages = event_ctx.event.prompt
|
query.messages = event_ctx.event.prompt
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
|
|
||||||
from ...core import app
|
from ...core import app
|
||||||
from ...core import entities as core_entities
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
class MessageHandler(metaclass=abc.ABCMeta):
|
class MessageHandler(metaclass=abc.ABCMeta):
|
||||||
@@ -19,7 +19,7 @@ class MessageHandler(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def handle(
|
async def handle(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|||||||
@@ -6,13 +6,15 @@ import traceback
|
|||||||
|
|
||||||
from .. import handler
|
from .. import handler
|
||||||
from ... import entities
|
from ... import entities
|
||||||
from ....core import entities as core_entities
|
|
||||||
from ....provider import runner as runner_module
|
from ....provider import runner as runner_module
|
||||||
from ....plugin import events
|
|
||||||
|
|
||||||
from ....platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
import langbot_plugin.api.entities.events as events
|
||||||
from ....utils import importutil
|
from ....utils import importutil
|
||||||
from ....provider import runners
|
from ....provider import runners
|
||||||
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(runners)
|
importutil.import_modules_in_pkg(runners)
|
||||||
|
|
||||||
@@ -20,7 +22,7 @@ importutil.import_modules_in_pkg(runners)
|
|||||||
class ChatMessageHandler(handler.MessageHandler):
|
class ChatMessageHandler(handler.MessageHandler):
|
||||||
async def handle(
|
async def handle(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
# 调API
|
# 调API
|
||||||
@@ -29,20 +31,20 @@ class ChatMessageHandler(handler.MessageHandler):
|
|||||||
# 触发插件事件
|
# 触发插件事件
|
||||||
event_class = (
|
event_class = (
|
||||||
events.PersonNormalMessageReceived
|
events.PersonNormalMessageReceived
|
||||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||||
else events.GroupNormalMessageReceived
|
else events.GroupNormalMessageReceived
|
||||||
)
|
)
|
||||||
|
|
||||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
event = event_class(
|
||||||
event=event_class(
|
launcher_type=query.launcher_type.value,
|
||||||
launcher_type=query.launcher_type.value,
|
launcher_id=query.launcher_id,
|
||||||
launcher_id=query.launcher_id,
|
sender_id=query.sender_id,
|
||||||
sender_id=query.sender_id,
|
text_message=str(query.message_chain),
|
||||||
text_message=str(query.message_chain),
|
query=query,
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||||
|
|
||||||
if event_ctx.is_prevented_default():
|
if event_ctx.is_prevented_default():
|
||||||
if event_ctx.event.reply is not None:
|
if event_ctx.event.reply is not None:
|
||||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||||
|
|||||||
@@ -4,16 +4,17 @@ import typing
|
|||||||
|
|
||||||
from .. import handler
|
from .. import handler
|
||||||
from ... import entities
|
from ... import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.provider.message as provider_message
|
||||||
from ....provider import entities as llm_entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ....plugin import events
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
from ....platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.events as events
|
||||||
|
|
||||||
|
|
||||||
class CommandHandler(handler.MessageHandler):
|
class CommandHandler(handler.MessageHandler):
|
||||||
async def handle(
|
async def handle(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|
||||||
@@ -28,23 +29,23 @@ class CommandHandler(handler.MessageHandler):
|
|||||||
|
|
||||||
event_class = (
|
event_class = (
|
||||||
events.PersonCommandSent
|
events.PersonCommandSent
|
||||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
if query.launcher_type == provider_session.LauncherTypes.PERSON
|
||||||
else events.GroupCommandSent
|
else events.GroupCommandSent
|
||||||
)
|
)
|
||||||
|
|
||||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
event = event_class(
|
||||||
event=event_class(
|
launcher_type=query.launcher_type.value,
|
||||||
launcher_type=query.launcher_type.value,
|
launcher_id=query.launcher_id,
|
||||||
launcher_id=query.launcher_id,
|
sender_id=query.sender_id,
|
||||||
sender_id=query.sender_id,
|
command=spt[0],
|
||||||
command=spt[0],
|
params=spt[1:] if len(spt) > 1 else [],
|
||||||
params=spt[1:] if len(spt) > 1 else [],
|
text_message=str(query.message_chain),
|
||||||
text_message=str(query.message_chain),
|
is_admin=(privilege == 2),
|
||||||
is_admin=(privilege == 2),
|
query=query,
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||||
|
|
||||||
if event_ctx.is_prevented_default():
|
if event_ctx.is_prevented_default():
|
||||||
if event_ctx.event.reply is not None:
|
if event_ctx.event.reply is not None:
|
||||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||||
@@ -64,7 +65,7 @@ class CommandHandler(handler.MessageHandler):
|
|||||||
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
|
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
|
||||||
if ret.error is not None:
|
if ret.error is not None:
|
||||||
query.resp_messages.append(
|
query.resp_messages.append(
|
||||||
llm_entities.Message(
|
provider_message.Message(
|
||||||
role='command',
|
role='command',
|
||||||
content=str(ret.error),
|
content=str(ret.error),
|
||||||
)
|
)
|
||||||
@@ -74,16 +75,16 @@ class CommandHandler(handler.MessageHandler):
|
|||||||
|
|
||||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
elif ret.text is not None or ret.image_url is not None:
|
elif ret.text is not None or ret.image_url is not None:
|
||||||
content: list[llm_entities.ContentElement] = []
|
content: list[provider_message.ContentElement] = []
|
||||||
|
|
||||||
if ret.text is not None:
|
if ret.text is not None:
|
||||||
content.append(llm_entities.ContentElement.from_text(ret.text))
|
content.append(provider_message.ContentElement.from_text(ret.text))
|
||||||
|
|
||||||
if ret.image_url is not None:
|
if ret.image_url is not None:
|
||||||
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
|
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
|
||||||
|
|
||||||
query.resp_messages.append(
|
query.resp_messages.append(
|
||||||
llm_entities.Message(
|
provider_message.Message(
|
||||||
role='command',
|
role='command',
|
||||||
content=content,
|
content=content,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from ...core import entities as core_entities
|
|
||||||
from . import handler
|
from . import handler
|
||||||
from .handlers import chat, command
|
from .handlers import chat, command
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import stage
|
from .. import stage
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('MessageProcessor')
|
@stage.stage_class('MessageProcessor')
|
||||||
@@ -30,7 +30,7 @@ class Processor(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> entities.StageProcessResult:
|
) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|||||||
@@ -2,7 +2,8 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
preregistered_algos: list[typing.Type[ReteLimitAlgo]] = []
|
||||||
@@ -33,7 +34,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def require_access(
|
async def require_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -53,7 +54,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def release_access(
|
async def release_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ import asyncio
|
|||||||
import time
|
import time
|
||||||
import typing
|
import typing
|
||||||
from .. import algo
|
from .. import algo
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
# 固定窗口算法
|
# 固定窗口算法
|
||||||
@@ -32,7 +32,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
|||||||
|
|
||||||
async def require_access(
|
async def require_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@@ -91,7 +91,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
|
|||||||
|
|
||||||
async def release_access(
|
async def release_access(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
launcher_type: str,
|
launcher_type: str,
|
||||||
launcher_id: typing.Union[int, str],
|
launcher_id: typing.Union[int, str],
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ import typing
|
|||||||
|
|
||||||
from .. import entities, stage
|
from .. import entities, stage
|
||||||
from . import algo
|
from . import algo
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
from . import algos
|
from . import algos
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(algos)
|
importutil.import_modules_in_pkg(algos)
|
||||||
@@ -39,7 +40,7 @@ class RateLimit(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> typing.Union[
|
) -> typing.Union[
|
||||||
entities.StageProcessResult,
|
entities.StageProcessResult,
|
||||||
|
|||||||
@@ -4,18 +4,18 @@ import random
|
|||||||
import asyncio
|
import asyncio
|
||||||
|
|
||||||
|
|
||||||
from ...platform.types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('SendResponseBackStage')
|
@stage.stage_class('SendResponseBackStage')
|
||||||
class SendResponseBackStage(stage.PipelineStage):
|
class SendResponseBackStage(stage.PipelineStage):
|
||||||
"""发送响应消息"""
|
"""发送响应消息"""
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
|
|
||||||
random_range = (
|
random_range = (
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import pydantic.v1 as pydantic
|
import pydantic
|
||||||
|
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
|
||||||
|
|
||||||
class RuleJudgeResult(pydantic.BaseModel):
|
class RuleJudgeResult(pydantic.BaseModel):
|
||||||
|
|||||||
@@ -4,9 +4,10 @@ from __future__ import annotations
|
|||||||
from . import rule
|
from . import rule
|
||||||
|
|
||||||
from .. import stage, entities
|
from .. import stage, entities
|
||||||
from ...core import entities as core_entities
|
|
||||||
from ...utils import importutil
|
from ...utils import importutil
|
||||||
|
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
from . import rules
|
from . import rules
|
||||||
|
|
||||||
importutil.import_modules_in_pkg(rules)
|
importutil.import_modules_in_pkg(rules)
|
||||||
@@ -32,7 +33,7 @@ class GroupRespondRuleCheckStage(stage.PipelineStage):
|
|||||||
await rule_inst.initialize()
|
await rule_inst.initialize()
|
||||||
self.rule_matchers.append(rule_inst)
|
self.rule_matchers.append(rule_inst)
|
||||||
|
|
||||||
async def process(self, query: core_entities.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
async def process(self, query: pipeline_query.Query, stage_inst_name: str) -> entities.StageProcessResult:
|
||||||
if query.launcher_type.value != 'group': # 只处理群消息
|
if query.launcher_type.value != 'group': # 只处理群消息
|
||||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||||
|
|
||||||
|
|||||||
@@ -2,10 +2,11 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ...core import app, entities as core_entities
|
from ...core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
preregisetered_rules: list[typing.Type[GroupRespondRule]] = []
|
||||||
@@ -39,7 +40,7 @@ class GroupRespondRule(metaclass=abc.ABCMeta):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
"""判断消息是否匹配规则"""
|
"""判断消息是否匹配规则"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ....platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('at-bot')
|
@rule_model.rule_class('at-bot')
|
||||||
@@ -14,7 +14,7 @@ class AtBotRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
if message_chain.has(platform_message.At(query.adapter.bot_account_id)) and rule_dict['at']:
|
||||||
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
message_chain.remove(platform_message.At(query.adapter.bot_account_id))
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ....platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('prefix')
|
@rule_model.rule_class('prefix')
|
||||||
@@ -11,7 +11,7 @@ class PrefixRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
prefixes = rule_dict['prefix']
|
prefixes = rule_dict['prefix']
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import random
|
|||||||
|
|
||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ....platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('random')
|
@rule_model.rule_class('random')
|
||||||
@@ -14,7 +14,7 @@ class RandomRespRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
random_rate = rule_dict['random']
|
random_rate = rule_dict['random']
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,8 @@ import re
|
|||||||
|
|
||||||
from .. import rule as rule_model
|
from .. import rule as rule_model
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from ....core import entities as core_entities
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ....platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
@rule_model.rule_class('regexp')
|
@rule_model.rule_class('regexp')
|
||||||
@@ -14,7 +14,7 @@ class RegExpRule(rule_model.GroupRespondRule):
|
|||||||
message_text: str,
|
message_text: str,
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
rule_dict: dict,
|
rule_dict: dict,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
) -> entities.RuleJudgeResult:
|
) -> entities.RuleJudgeResult:
|
||||||
regexps = rule_dict['regexp']
|
regexps = rule_dict['regexp']
|
||||||
|
|
||||||
|
|||||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
|||||||
import abc
|
import abc
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from ..core import app, entities as core_entities
|
from ..core import app
|
||||||
from . import entities
|
from . import entities
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
|
||||||
|
|
||||||
preregistered_stages: dict[str, type[PipelineStage]] = {}
|
preregistered_stages: dict[str, type[PipelineStage]] = {}
|
||||||
@@ -33,7 +34,7 @@ class PipelineStage(metaclass=abc.ABCMeta):
|
|||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> typing.Union[
|
) -> typing.Union[
|
||||||
entities.StageProcessResult,
|
entities.StageProcessResult,
|
||||||
|
|||||||
@@ -2,12 +2,12 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
|
|
||||||
from ...core import entities as core_entities
|
|
||||||
from .. import entities
|
from .. import entities
|
||||||
from .. import stage
|
from .. import stage
|
||||||
from ...plugin import events
|
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
|
||||||
|
import langbot_plugin.api.entities.events as events
|
||||||
|
|
||||||
|
|
||||||
@stage.stage_class('ResponseWrapper')
|
@stage.stage_class('ResponseWrapper')
|
||||||
@@ -25,7 +25,7 @@ class ResponseWrapper(stage.PipelineStage):
|
|||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
query: core_entities.Query,
|
query: pipeline_query.Query,
|
||||||
stage_inst_name: str,
|
stage_inst_name: str,
|
||||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||||
"""处理"""
|
"""处理"""
|
||||||
@@ -58,21 +58,22 @@ class ResponseWrapper(stage.PipelineStage):
|
|||||||
reply_text = str(result.get_content_platform_message_chain())
|
reply_text = str(result.get_content_platform_message_chain())
|
||||||
|
|
||||||
# ============= 触发插件事件 ===============
|
# ============= 触发插件事件 ===============
|
||||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
event = events.NormalMessageResponded(
|
||||||
event=events.NormalMessageResponded(
|
launcher_type=query.launcher_type.value,
|
||||||
launcher_type=query.launcher_type.value,
|
launcher_id=query.launcher_id,
|
||||||
launcher_id=query.launcher_id,
|
sender_id=query.sender_id,
|
||||||
sender_id=query.sender_id,
|
session=session,
|
||||||
session=session,
|
prefix='',
|
||||||
prefix='',
|
response_text=reply_text,
|
||||||
response_text=reply_text,
|
finish_reason='stop',
|
||||||
finish_reason='stop',
|
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
if result.tool_calls is not None
|
||||||
if result.tool_calls is not None
|
else [],
|
||||||
else [],
|
query=query,
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||||
|
|
||||||
if event_ctx.is_prevented_default():
|
if event_ctx.is_prevented_default():
|
||||||
yield entities.StageProcessResult(
|
yield entities.StageProcessResult(
|
||||||
result_type=entities.ResultType.INTERRUPT,
|
result_type=entities.ResultType.INTERRUPT,
|
||||||
@@ -96,26 +97,26 @@ class ResponseWrapper(stage.PipelineStage):
|
|||||||
reply_text = f'调用函数 {".".join(function_names)}...'
|
reply_text = f'调用函数 {".".join(function_names)}...'
|
||||||
|
|
||||||
query.resp_message_chain.append(
|
query.resp_message_chain.append(
|
||||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
|
||||||
)
|
)
|
||||||
|
|
||||||
if query.pipeline_config['output']['misc']['track-function-calls']:
|
if query.pipeline_config['output']['misc']['track-function-calls']:
|
||||||
event_ctx = await self.ap.plugin_mgr.emit_event(
|
event = events.NormalMessageResponded(
|
||||||
event=events.NormalMessageResponded(
|
launcher_type=query.launcher_type.value,
|
||||||
launcher_type=query.launcher_type.value,
|
launcher_id=query.launcher_id,
|
||||||
launcher_id=query.launcher_id,
|
sender_id=query.sender_id,
|
||||||
sender_id=query.sender_id,
|
session=session,
|
||||||
session=session,
|
prefix='',
|
||||||
prefix='',
|
response_text=reply_text,
|
||||||
response_text=reply_text,
|
finish_reason='stop',
|
||||||
finish_reason='stop',
|
funcs_called=[fc.function.name for fc in result.tool_calls]
|
||||||
funcs_called=[fc.function.name for fc in result.tool_calls]
|
if result.tool_calls is not None
|
||||||
if result.tool_calls is not None
|
else [],
|
||||||
else [],
|
query=query,
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
event_ctx = await self.ap.plugin_connector.emit_event(event)
|
||||||
|
|
||||||
if event_ctx.is_prevented_default():
|
if event_ctx.is_prevented_default():
|
||||||
yield entities.StageProcessResult(
|
yield entities.StageProcessResult(
|
||||||
result_type=entities.ResultType.INTERRUPT,
|
result_type=entities.ResultType.INTERRUPT,
|
||||||
@@ -124,12 +125,12 @@ class ResponseWrapper(stage.PipelineStage):
|
|||||||
else:
|
else:
|
||||||
if event_ctx.event.reply is not None:
|
if event_ctx.event.reply is not None:
|
||||||
query.resp_message_chain.append(
|
query.resp_message_chain.append(
|
||||||
platform_message.MessageChain(event_ctx.event.reply)
|
platform_message.MessageChain(text=event_ctx.event.reply)
|
||||||
)
|
)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
query.resp_message_chain.append(
|
query.resp_message_chain.append(
|
||||||
platform_message.MessageChain([platform_message.Plain(reply_text)])
|
platform_message.MessageChain([platform_message.Plain(text=reply_text)])
|
||||||
)
|
)
|
||||||
|
|
||||||
yield entities.StageProcessResult(
|
yield entities.StageProcessResult(
|
||||||
|
|||||||
@@ -1,160 +0,0 @@
|
|||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
# MessageSource的适配器
|
|
||||||
import typing
|
|
||||||
import abc
|
|
||||||
|
|
||||||
|
|
||||||
from ..core import app
|
|
||||||
from .types import message as platform_message
|
|
||||||
from .types import events as platform_events
|
|
||||||
from .logger import EventLogger
|
|
||||||
|
|
||||||
|
|
||||||
class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
|
||||||
"""消息平台适配器基类"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
|
|
||||||
bot_account_id: int
|
|
||||||
"""机器人账号ID,需要在初始化时设置"""
|
|
||||||
|
|
||||||
config: dict
|
|
||||||
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
logger: EventLogger
|
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
|
||||||
"""初始化适配器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config (dict): 对应的配置
|
|
||||||
ap (app.Application): 应用上下文
|
|
||||||
"""
|
|
||||||
self.config = config
|
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
|
||||||
"""主动发送消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
target_type (str): 目标类型,`person`或`group`
|
|
||||||
target_id (str): 目标ID
|
|
||||||
message (platform.types.MessageChain): 消息链
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def reply_message(
|
|
||||||
self,
|
|
||||||
message_source: platform_events.MessageEvent,
|
|
||||||
message: platform_message.MessageChain,
|
|
||||||
quote_origin: bool = False,
|
|
||||||
):
|
|
||||||
"""回复消息
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_source (platform.types.MessageEvent): 消息源事件
|
|
||||||
message (platform.types.MessageChain): 消息链
|
|
||||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def is_muted(self, group_id: int) -> bool:
|
|
||||||
"""获取账号是否在指定群被禁言"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def register_listener(
|
|
||||||
self,
|
|
||||||
event_type: typing.Type[platform_message.Event],
|
|
||||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
|
|
||||||
):
|
|
||||||
"""注册事件监听器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
|
||||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
def unregister_listener(
|
|
||||||
self,
|
|
||||||
event_type: typing.Type[platform_message.Event],
|
|
||||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
|
|
||||||
):
|
|
||||||
"""注销事件监听器
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event_type (typing.Type[platform.types.Event]): 事件类型
|
|
||||||
callback (typing.Callable[[platform.types.Event], None]): 回调函数,接收一个参数,为事件
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def run_async(self):
|
|
||||||
"""异步运行"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def kill(self) -> bool:
|
|
||||||
"""关闭适配器
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
bool: 是否成功关闭,热重载时若此函数返回False则不会重载MessageSource底层
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class MessageConverter:
|
|
||||||
"""消息链转换器基类"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def yiri2target(message_chain: platform_message.MessageChain):
|
|
||||||
"""将源平台消息链转换为目标平台消息链
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_chain (platform.types.MessageChain): 源平台消息链
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
typing.Any: 目标平台消息链
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def target2yiri(message_chain: typing.Any) -> platform_message.MessageChain:
|
|
||||||
"""将目标平台消息链转换为源平台消息链
|
|
||||||
|
|
||||||
Args:
|
|
||||||
message_chain (typing.Any): 目标平台消息链
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
platform.types.MessageChain: 源平台消息链
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
|
|
||||||
class EventConverter:
|
|
||||||
"""事件转换器基类"""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def yiri2target(event: typing.Type[platform_message.Event]):
|
|
||||||
"""将源平台事件转换为目标平台事件
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event (typing.Type[platform.types.Event]): 源平台事件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
typing.Any: 目标平台事件
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def target2yiri(event: typing.Any) -> platform_message.Event:
|
|
||||||
"""将目标平台事件的调用参数转换为源平台的事件参数对象
|
|
||||||
|
|
||||||
Args:
|
|
||||||
event (typing.Any): 目标平台事件
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
typing.Type[platform.types.Event]: 源平台事件
|
|
||||||
"""
|
|
||||||
raise NotImplementedError
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
apiVersion: v1
|
|
||||||
kind: ComponentTemplate
|
|
||||||
metadata:
|
|
||||||
name: MessagePlatformAdapter
|
|
||||||
label:
|
|
||||||
en_US: Message Platform Adapter
|
|
||||||
zh_Hans: 消息平台适配器模板类
|
|
||||||
spec:
|
|
||||||
type:
|
|
||||||
- python
|
|
||||||
execution:
|
|
||||||
python:
|
|
||||||
path: ./adapter.py
|
|
||||||
attr: MessagePlatformAdapter
|
|
||||||
@@ -1,15 +1,10 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import sys
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import traceback
|
import traceback
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
|
||||||
# FriendMessage, Image, MessageChain, Plain
|
|
||||||
from . import adapter as msadapter
|
|
||||||
|
|
||||||
from ..core import app, entities as core_entities, taskmgr
|
from ..core import app, entities as core_entities, taskmgr
|
||||||
from .types import events as platform_events, message as platform_message
|
|
||||||
|
|
||||||
from ..discover import engine
|
from ..discover import engine
|
||||||
|
|
||||||
@@ -19,10 +14,10 @@ from ..entity.errors import platform as platform_errors
|
|||||||
|
|
||||||
from .logger import EventLogger
|
from .logger import EventLogger
|
||||||
|
|
||||||
# 处理 3.4 移除了 YiriMirai 之后,插件的兼容性问题
|
import langbot_plugin.api.entities.builtin.provider.session as provider_session
|
||||||
from . import types as mirai
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
sys.modules['mirai'] = mirai
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
|
|
||||||
|
|
||||||
class RuntimeBot:
|
class RuntimeBot:
|
||||||
@@ -34,7 +29,7 @@ class RuntimeBot:
|
|||||||
|
|
||||||
enable: bool
|
enable: bool
|
||||||
|
|
||||||
adapter: msadapter.MessagePlatformAdapter
|
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||||
|
|
||||||
task_wrapper: taskmgr.TaskWrapper
|
task_wrapper: taskmgr.TaskWrapper
|
||||||
|
|
||||||
@@ -46,7 +41,7 @@ class RuntimeBot:
|
|||||||
self,
|
self,
|
||||||
ap: app.Application,
|
ap: app.Application,
|
||||||
bot_entity: persistence_bot.Bot,
|
bot_entity: persistence_bot.Bot,
|
||||||
adapter: msadapter.MessagePlatformAdapter,
|
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||||
logger: EventLogger,
|
logger: EventLogger,
|
||||||
):
|
):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
@@ -59,7 +54,7 @@ class RuntimeBot:
|
|||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
async def on_friend_message(
|
async def on_friend_message(
|
||||||
event: platform_events.FriendMessage,
|
event: platform_events.FriendMessage,
|
||||||
adapter: msadapter.MessagePlatformAdapter,
|
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||||
):
|
):
|
||||||
image_components = [
|
image_components = [
|
||||||
component for component in event.message_chain if isinstance(component, platform_message.Image)
|
component for component in event.message_chain if isinstance(component, platform_message.Image)
|
||||||
@@ -73,7 +68,7 @@ class RuntimeBot:
|
|||||||
|
|
||||||
await self.ap.query_pool.add_query(
|
await self.ap.query_pool.add_query(
|
||||||
bot_uuid=self.bot_entity.uuid,
|
bot_uuid=self.bot_entity.uuid,
|
||||||
launcher_type=core_entities.LauncherTypes.PERSON,
|
launcher_type=provider_session.LauncherTypes.PERSON,
|
||||||
launcher_id=event.sender.id,
|
launcher_id=event.sender.id,
|
||||||
sender_id=event.sender.id,
|
sender_id=event.sender.id,
|
||||||
message_event=event,
|
message_event=event,
|
||||||
@@ -84,7 +79,7 @@ class RuntimeBot:
|
|||||||
|
|
||||||
async def on_group_message(
|
async def on_group_message(
|
||||||
event: platform_events.GroupMessage,
|
event: platform_events.GroupMessage,
|
||||||
adapter: msadapter.MessagePlatformAdapter,
|
adapter: abstract_platform_adapter.AbstractMessagePlatformAdapter,
|
||||||
):
|
):
|
||||||
image_components = [
|
image_components = [
|
||||||
component for component in event.message_chain if isinstance(component, platform_message.Image)
|
component for component in event.message_chain if isinstance(component, platform_message.Image)
|
||||||
@@ -98,7 +93,7 @@ class RuntimeBot:
|
|||||||
|
|
||||||
await self.ap.query_pool.add_query(
|
await self.ap.query_pool.add_query(
|
||||||
bot_uuid=self.bot_entity.uuid,
|
bot_uuid=self.bot_entity.uuid,
|
||||||
launcher_type=core_entities.LauncherTypes.GROUP,
|
launcher_type=provider_session.LauncherTypes.GROUP,
|
||||||
launcher_id=event.group.id,
|
launcher_id=event.group.id,
|
||||||
sender_id=event.sender.id,
|
sender_id=event.sender.id,
|
||||||
message_event=event,
|
message_event=event,
|
||||||
@@ -151,7 +146,7 @@ class PlatformManager:
|
|||||||
|
|
||||||
adapter_components: list[engine.Component]
|
adapter_components: list[engine.Component]
|
||||||
|
|
||||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]]
|
adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]]
|
||||||
|
|
||||||
def __init__(self, ap: app.Application = None):
|
def __init__(self, ap: app.Application = None):
|
||||||
self.ap = ap
|
self.ap = ap
|
||||||
@@ -161,7 +156,7 @@ class PlatformManager:
|
|||||||
|
|
||||||
async def initialize(self):
|
async def initialize(self):
|
||||||
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
self.adapter_components = self.ap.discover.get_components_by_kind('MessagePlatformAdapter')
|
||||||
adapter_dict: dict[str, type[msadapter.MessagePlatformAdapter]] = {}
|
adapter_dict: dict[str, type[abstract_platform_adapter.AbstractMessagePlatformAdapter]] = {}
|
||||||
for component in self.adapter_components:
|
for component in self.adapter_components:
|
||||||
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
adapter_dict[component.metadata.name] = component.get_python_component_class()
|
||||||
self.adapter_dict = adapter_dict
|
self.adapter_dict = adapter_dict
|
||||||
@@ -172,9 +167,10 @@ class PlatformManager:
|
|||||||
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
|
webchat_logger = EventLogger(name='webchat-adapter', ap=self.ap)
|
||||||
webchat_adapter_inst = webchat_adapter_class(
|
webchat_adapter_inst = webchat_adapter_class(
|
||||||
{},
|
{},
|
||||||
self.ap,
|
|
||||||
webchat_logger,
|
webchat_logger,
|
||||||
|
ap=self.ap,
|
||||||
)
|
)
|
||||||
|
webchat_adapter_inst.ap = self.ap
|
||||||
|
|
||||||
self.webchat_proxy_bot = RuntimeBot(
|
self.webchat_proxy_bot = RuntimeBot(
|
||||||
ap=self.ap,
|
ap=self.ap,
|
||||||
@@ -193,7 +189,7 @@ class PlatformManager:
|
|||||||
|
|
||||||
await self.load_bots_from_db()
|
await self.load_bots_from_db()
|
||||||
|
|
||||||
def get_running_adapters(self) -> list[msadapter.MessagePlatformAdapter]:
|
def get_running_adapters(self) -> list[abstract_platform_adapter.AbstractMessagePlatformAdapter]:
|
||||||
return [bot.adapter for bot in self.bots if bot.enable]
|
return [bot.adapter for bot in self.bots if bot.enable]
|
||||||
|
|
||||||
async def load_bots_from_db(self):
|
async def load_bots_from_db(self):
|
||||||
@@ -231,7 +227,6 @@ class PlatformManager:
|
|||||||
|
|
||||||
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
adapter_inst = self.adapter_dict[bot_entity.adapter](
|
||||||
bot_entity.adapter_config,
|
bot_entity.adapter_config,
|
||||||
self.ap,
|
|
||||||
logger,
|
logger,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -274,43 +269,6 @@ class PlatformManager:
|
|||||||
return component
|
return component
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def write_back_config(
|
|
||||||
self,
|
|
||||||
adapter_name: str,
|
|
||||||
adapter_inst: msadapter.MessagePlatformAdapter,
|
|
||||||
config: dict,
|
|
||||||
):
|
|
||||||
# index = -2
|
|
||||||
|
|
||||||
# for i, adapter in enumerate(self.adapters):
|
|
||||||
# if adapter == adapter_inst:
|
|
||||||
# index = i
|
|
||||||
# break
|
|
||||||
|
|
||||||
# if index == -2:
|
|
||||||
# raise Exception('平台适配器未找到')
|
|
||||||
|
|
||||||
# # 只修改启用的适配器
|
|
||||||
# real_index = -1
|
|
||||||
|
|
||||||
# for i, adapter in enumerate(self.ap.platform_cfg.data['platform-adapters']):
|
|
||||||
# if adapter['enable']:
|
|
||||||
# index -= 1
|
|
||||||
# if index == -1:
|
|
||||||
# real_index = i
|
|
||||||
# break
|
|
||||||
|
|
||||||
# new_cfg = {
|
|
||||||
# 'adapter': adapter_name,
|
|
||||||
# 'enable': True,
|
|
||||||
# **config
|
|
||||||
# }
|
|
||||||
# self.ap.platform_cfg.data['platform-adapters'][real_index] = new_cfg
|
|
||||||
# await self.ap.platform_cfg.dump_config()
|
|
||||||
|
|
||||||
# TODO implement this
|
|
||||||
pass
|
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
# This method will only be called when the application launching
|
# This method will only be called when the application launching
|
||||||
await self.webchat_proxy_bot.run()
|
await self.webchat_proxy_bot.run()
|
||||||
|
|||||||
@@ -9,7 +9,8 @@ import traceback
|
|||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
from ..core import app
|
from ..core import app
|
||||||
from .types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
|
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_event_logger
|
||||||
|
|
||||||
|
|
||||||
class EventLogLevel(enum.Enum):
|
class EventLogLevel(enum.Enum):
|
||||||
@@ -55,7 +56,7 @@ MAX_LOG_COUNT = 200
|
|||||||
DELETE_COUNT_PER_TIME = 50
|
DELETE_COUNT_PER_TIME = 50
|
||||||
|
|
||||||
|
|
||||||
class EventLogger:
|
class EventLogger(abstract_platform_event_logger.AbstractEventLogger):
|
||||||
"""used for logging bot events"""
|
"""used for logging bot events"""
|
||||||
|
|
||||||
ap: app.Application
|
ap: app.Application
|
||||||
|
|||||||
@@ -5,18 +5,17 @@ import traceback
|
|||||||
import datetime
|
import datetime
|
||||||
|
|
||||||
import aiocqhttp
|
import aiocqhttp
|
||||||
|
import pydantic
|
||||||
|
|
||||||
from .. import adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ...core import app
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ..types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ..types import entities as platform_entities
|
|
||||||
from ...utils import image
|
from ...utils import image
|
||||||
from ..logger import EventLogger
|
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||||
|
|
||||||
|
|
||||||
class AiocqhttpMessageConverter(adapter.MessageConverter):
|
class AiocqhttpMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(
|
async def yiri2target(
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
@@ -71,15 +70,13 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||||||
elif msg.face_type=='dice':
|
elif msg.face_type=='dice':
|
||||||
msg_list.append(aiocqhttp.MessageSegment.dice())
|
msg_list.append(aiocqhttp.MessageSegment.dice())
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
|
msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
|
||||||
|
|
||||||
return msg_list, msg_id, msg_time
|
return msg_list, msg_id, msg_time
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def target2yiri(message: str, message_id: int = -1,bot=None):
|
async def target2yiri(message: str, message_id: int = -1, bot=None):
|
||||||
print(message)
|
|
||||||
message = aiocqhttp.Message(message)
|
message = aiocqhttp.Message(message)
|
||||||
|
|
||||||
def get_face_name(face_id):
|
def get_face_name(face_id):
|
||||||
@@ -119,30 +116,28 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||||||
return face_code_dict.get(face_id,'')
|
return face_code_dict.get(face_id,'')
|
||||||
|
|
||||||
async def process_message_data(msg_data, reply_list):
|
async def process_message_data(msg_data, reply_list):
|
||||||
if msg_data["type"] == "image":
|
if msg_data['type'] == 'image':
|
||||||
image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url'])
|
image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url'])
|
||||||
reply_list.append(
|
reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
|
||||||
platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
|
|
||||||
|
|
||||||
elif msg_data["type"] == "text":
|
elif msg_data['type'] == 'text':
|
||||||
reply_list.append(platform_message.Plain(text=msg_data["data"]["text"]))
|
reply_list.append(platform_message.Plain(text=msg_data['data']['text']))
|
||||||
|
|
||||||
elif msg_data["type"] == "forward": # 这里来应该传入转发消息组,暂时传入qoute
|
elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组,暂时传入qoute
|
||||||
for forward_msg_datas in msg_data["data"]["content"]:
|
for forward_msg_datas in msg_data['data']['content']:
|
||||||
for forward_msg_data in forward_msg_datas["message"]:
|
for forward_msg_data in forward_msg_datas['message']:
|
||||||
await process_message_data(forward_msg_data, reply_list)
|
await process_message_data(forward_msg_data, reply_list)
|
||||||
|
|
||||||
elif msg_data["type"] == "at":
|
elif msg_data['type'] == 'at':
|
||||||
if msg_data["data"]['qq'] == 'all':
|
if msg_data['data']['qq'] == 'all':
|
||||||
reply_list.append(platform_message.AtAll())
|
reply_list.append(platform_message.AtAll())
|
||||||
else:
|
else:
|
||||||
reply_list.append(
|
reply_list.append(
|
||||||
platform_message.At(
|
platform_message.At(
|
||||||
target=msg_data["data"]['qq'],
|
target=msg_data['data']['qq'],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
yiri_msg_list = []
|
yiri_msg_list = []
|
||||||
|
|
||||||
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
|
||||||
@@ -178,14 +173,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||||||
# await process_message_data(msg_data, yiri_msg_list)
|
# await process_message_data(msg_data, yiri_msg_list)
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
elif msg.type == 'reply': # 此处处理引用消息传入Qoute
|
elif msg.type == 'reply': # 此处处理引用消息传入Qoute
|
||||||
msg_datas = await bot.get_msg(message_id=msg.data["id"])
|
msg_datas = await bot.get_msg(message_id=msg.data['id'])
|
||||||
|
|
||||||
for msg_data in msg_datas["message"]:
|
for msg_data in msg_datas['message']:
|
||||||
await process_message_data(msg_data, reply_list)
|
await process_message_data(msg_data, reply_list)
|
||||||
|
|
||||||
reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list)
|
reply_msg = platform_message.Quote(
|
||||||
|
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
|
||||||
|
)
|
||||||
yiri_msg_list.append(reply_msg)
|
yiri_msg_list.append(reply_msg)
|
||||||
|
|
||||||
elif msg.type == 'file':
|
elif msg.type == 'file':
|
||||||
@@ -194,6 +190,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||||||
file_data = await bot.get_file(file_id=file_id)
|
file_data = await bot.get_file(file_id=file_id)
|
||||||
file_name = file_data.get('file_name')
|
file_name = file_data.get('file_name')
|
||||||
file_path = file_data.get('file')
|
file_path = file_data.get('file')
|
||||||
|
_ = file_path
|
||||||
file_url = file_data.get('file_url')
|
file_url = file_data.get('file_url')
|
||||||
file_size = file_data.get('file_size')
|
file_size = file_data.get('file_size')
|
||||||
yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
|
yiri_msg_list.append(platform_message.File(id=file_id, name=file_name,url=file_url,size=file_size))
|
||||||
@@ -210,32 +207,19 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
|
|||||||
face_id = msg.data['result']
|
face_id = msg.data['result']
|
||||||
yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子'))
|
yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子'))
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
chain = platform_message.MessageChain(yiri_msg_list)
|
chain = platform_message.MessageChain(yiri_msg_list)
|
||||||
|
|
||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
|
||||||
|
class AiocqhttpEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AiocqhttpEventConverter(adapter.EventConverter):
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
|
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
|
||||||
return event.source_platform_object
|
return event.source_platform_object
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def target2yiri(event: aiocqhttp.Event,bot=None):
|
async def target2yiri(event: aiocqhttp.Event, bot=None):
|
||||||
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot)
|
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if event.message_type == 'group':
|
if event.message_type == 'group':
|
||||||
@@ -279,23 +263,19 @@ class AiocqhttpEventConverter(adapter.EventConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
class AiocqhttpAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
bot: aiocqhttp.CQHttp
|
bot: aiocqhttp.CQHttp = pydantic.Field(exclude=True, default_factory=aiocqhttp.CQHttp)
|
||||||
|
|
||||||
bot_account_id: int
|
|
||||||
|
|
||||||
message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter()
|
message_converter: AiocqhttpMessageConverter = AiocqhttpMessageConverter()
|
||||||
event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter()
|
event_converter: AiocqhttpEventConverter = AiocqhttpEventConverter()
|
||||||
|
|
||||||
config: dict
|
|
||||||
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
|
on_websocket_connection_event_cache: typing.List[typing.Callable[[aiocqhttp.Event], None]] = []
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||||
self.config = config
|
super().__init__(
|
||||||
self.logger = logger
|
config=config,
|
||||||
|
logger=logger,
|
||||||
|
)
|
||||||
|
|
||||||
async def shutdown_trigger_placeholder():
|
async def shutdown_trigger_placeholder():
|
||||||
while True:
|
while True:
|
||||||
@@ -303,7 +283,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
|
self.config['shutdown_trigger'] = shutdown_trigger_placeholder
|
||||||
|
|
||||||
self.ap = ap
|
|
||||||
self.on_websocket_connection_event_cache = []
|
self.on_websocket_connection_event_cache = []
|
||||||
|
|
||||||
if 'access-token' in config:
|
if 'access-token' in config:
|
||||||
@@ -316,7 +295,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
|||||||
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
|
||||||
|
|
||||||
if target_type == 'group':
|
if target_type == 'group':
|
||||||
|
|
||||||
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
|
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
|
||||||
elif target_type == 'person':
|
elif target_type == 'person':
|
||||||
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
|
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)
|
||||||
@@ -340,12 +318,14 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
async def on_message(event: aiocqhttp.Event):
|
async def on_message(event: aiocqhttp.Event):
|
||||||
self.bot_account_id = event.self_id
|
self.bot_account_id = event.self_id
|
||||||
try:
|
try:
|
||||||
return await callback(await self.event_converter.target2yiri(event,self.bot), self)
|
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
|
||||||
except Exception:
|
except Exception:
|
||||||
await self.logger.error(f'Error in on_message: {traceback.format_exc()}')
|
await self.logger.error(f'Error in on_message: {traceback.format_exc()}')
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
@@ -371,7 +351,9 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def unregister_listener(
|
def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
return super().unregister_listener(event_type, callback)
|
return super().unregister_listener(event_type, callback)
|
||||||
|
|
||||||
|
|||||||
@@ -1,18 +1,16 @@
|
|||||||
import traceback
|
import traceback
|
||||||
import typing
|
import typing
|
||||||
from libs.dingtalk_api.dingtalkevent import DingTalkEvent
|
from libs.dingtalk_api.dingtalkevent import DingTalkEvent
|
||||||
from pkg.platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from pkg.platform.adapter import MessagePlatformAdapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from .. import adapter
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ...core import app
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ..types import events as platform_events
|
|
||||||
from ..types import entities as platform_entities
|
|
||||||
from libs.dingtalk_api.api import DingTalkClient
|
from libs.dingtalk_api.api import DingTalkClient
|
||||||
import datetime
|
import datetime
|
||||||
from ..logger import EventLogger
|
from ..logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class DingTalkMessageConverter(adapter.MessageConverter):
|
class DingTalkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(message_chain: platform_message.MessageChain):
|
async def yiri2target(message_chain: platform_message.MessageChain):
|
||||||
content = ''
|
content = ''
|
||||||
@@ -48,7 +46,7 @@ class DingTalkMessageConverter(adapter.MessageConverter):
|
|||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
|
||||||
class DingTalkEventConverter(adapter.EventConverter):
|
class DingTalkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(event: platform_events.MessageEvent):
|
async def yiri2target(event: platform_events.MessageEvent):
|
||||||
return event.source_platform_object
|
return event.source_platform_object
|
||||||
@@ -92,17 +90,15 @@ class DingTalkEventConverter(adapter.EventConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
class DingTalkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
bot: DingTalkClient
|
bot: DingTalkClient
|
||||||
ap: app.Application
|
|
||||||
bot_account_id: str
|
bot_account_id: str
|
||||||
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
|
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
|
||||||
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
||||||
config: dict
|
config: dict
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: EventLogger):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
self.logger = logger
|
||||||
required_keys = [
|
required_keys = [
|
||||||
'client_id',
|
'client_id',
|
||||||
@@ -140,7 +136,9 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
async def on_message(event: DingTalkEvent):
|
async def on_message(event: DingTalkEvent):
|
||||||
try:
|
try:
|
||||||
@@ -174,6 +172,8 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
async def unregister_listener(
|
async def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: type,
|
event_type: type,
|
||||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
return super().unregister_listener(event_type, callback)
|
return super().unregister_listener(event_type, callback)
|
||||||
|
|||||||
@@ -8,19 +8,18 @@ import base64
|
|||||||
import uuid
|
import uuid
|
||||||
import os
|
import os
|
||||||
import datetime
|
import datetime
|
||||||
import io
|
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
import pydantic
|
||||||
|
|
||||||
from .. import adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ...core import app
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ..types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ..types import entities as platform_entities
|
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||||
from ..logger import EventLogger
|
|
||||||
|
|
||||||
|
|
||||||
class DiscordMessageConverter(adapter.MessageConverter):
|
class DiscordMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(
|
async def yiri2target(
|
||||||
message_chain: platform_message.MessageChain,
|
message_chain: platform_message.MessageChain,
|
||||||
@@ -36,88 +35,28 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
|||||||
for ele in message_chain:
|
for ele in message_chain:
|
||||||
if isinstance(ele, platform_message.Image):
|
if isinstance(ele, platform_message.Image):
|
||||||
image_bytes = None
|
image_bytes = None
|
||||||
filename = f'{uuid.uuid4()}.png' # 默认文件名
|
|
||||||
|
|
||||||
if ele.base64:
|
if ele.base64:
|
||||||
# 处理base64编码的图片
|
image_bytes = base64.b64decode(ele.base64)
|
||||||
if ele.base64.startswith('data:'):
|
|
||||||
# 从data URL中提取文件类型
|
|
||||||
data_header = ele.base64.split(',')[0]
|
|
||||||
if 'jpeg' in data_header or 'jpg' in data_header:
|
|
||||||
filename = f'{uuid.uuid4()}.jpg'
|
|
||||||
elif 'gif' in data_header:
|
|
||||||
filename = f'{uuid.uuid4()}.gif'
|
|
||||||
elif 'webp' in data_header:
|
|
||||||
filename = f'{uuid.uuid4()}.webp'
|
|
||||||
# 去掉data:image/xxx;base64,前缀
|
|
||||||
base64_data = ele.base64.split(',')[1]
|
|
||||||
else:
|
|
||||||
base64_data = ele.base64
|
|
||||||
image_bytes = base64.b64decode(base64_data)
|
|
||||||
elif ele.url:
|
elif ele.url:
|
||||||
# 从URL下载图片
|
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
async with session.get(ele.url) as response:
|
async with session.get(ele.url) as response:
|
||||||
image_bytes = await response.read()
|
image_bytes = await response.read()
|
||||||
# 从URL或Content-Type推断文件类型
|
|
||||||
content_type = response.headers.get('Content-Type', '')
|
|
||||||
if 'jpeg' in content_type or 'jpg' in content_type:
|
|
||||||
filename = f'{uuid.uuid4()}.jpg'
|
|
||||||
elif 'gif' in content_type:
|
|
||||||
filename = f'{uuid.uuid4()}.gif'
|
|
||||||
elif 'webp' in content_type:
|
|
||||||
filename = f'{uuid.uuid4()}.webp'
|
|
||||||
elif ele.url.lower().endswith(('.jpg', '.jpeg')):
|
|
||||||
filename = f'{uuid.uuid4()}.jpg'
|
|
||||||
elif ele.url.lower().endswith('.gif'):
|
|
||||||
filename = f'{uuid.uuid4()}.gif'
|
|
||||||
elif ele.url.lower().endswith('.webp'):
|
|
||||||
filename = f'{uuid.uuid4()}.webp'
|
|
||||||
elif ele.path:
|
elif ele.path:
|
||||||
# 从文件路径读取图片
|
with open(ele.path, 'rb') as f:
|
||||||
# 确保路径没有空字节
|
image_bytes = f.read()
|
||||||
clean_path = ele.path.replace('\x00', '')
|
|
||||||
clean_path = os.path.abspath(clean_path)
|
|
||||||
|
|
||||||
if not os.path.exists(clean_path):
|
|
||||||
continue # 跳过不存在的文件
|
|
||||||
|
|
||||||
try:
|
|
||||||
with open(clean_path, 'rb') as f:
|
|
||||||
image_bytes = f.read()
|
|
||||||
# 从文件路径获取文件名,保持原始扩展名
|
|
||||||
original_filename = os.path.basename(clean_path)
|
|
||||||
if original_filename and '.' in original_filename:
|
|
||||||
# 保持原始文件名的扩展名
|
|
||||||
ext = original_filename.split('.')[-1].lower()
|
|
||||||
filename = f'{uuid.uuid4()}.{ext}'
|
|
||||||
else:
|
|
||||||
# 如果没有扩展名,尝试从文件内容检测
|
|
||||||
if image_bytes.startswith(b'\xff\xd8\xff'):
|
|
||||||
filename = f'{uuid.uuid4()}.jpg'
|
|
||||||
elif image_bytes.startswith(b'GIF'):
|
|
||||||
filename = f'{uuid.uuid4()}.gif'
|
|
||||||
elif image_bytes.startswith(b'RIFF') and b'WEBP' in image_bytes[:20]:
|
|
||||||
filename = f'{uuid.uuid4()}.webp'
|
|
||||||
# 默认保持PNG
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error reading image file {clean_path}: {e}")
|
|
||||||
continue # 跳过读取失败的文件
|
|
||||||
|
|
||||||
if image_bytes:
|
image_files.append(discord.File(fp=image_bytes, filename=f'{uuid.uuid4()}.png'))
|
||||||
# 使用BytesIO创建文件对象,避免路径问题
|
|
||||||
import io
|
|
||||||
image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
|
|
||||||
elif isinstance(ele, platform_message.Plain):
|
elif isinstance(ele, platform_message.Plain):
|
||||||
text_string += ele.text
|
text_string += ele.text
|
||||||
elif isinstance(ele, platform_message.Forward):
|
elif isinstance(ele, platform_message.Forward):
|
||||||
for node in ele.node_list:
|
for node in ele.node_list:
|
||||||
(
|
(
|
||||||
node_text,
|
text_string,
|
||||||
node_images,
|
image_files,
|
||||||
) = await DiscordMessageConverter.yiri2target(node.message_chain)
|
) = await DiscordMessageConverter.yiri2target(node.message_chain)
|
||||||
text_string += node_text
|
text_string += text_string
|
||||||
image_files.extend(node_images)
|
image_files.extend(image_files)
|
||||||
|
|
||||||
return text_string, image_files
|
return text_string, image_files
|
||||||
|
|
||||||
@@ -173,7 +112,7 @@ class DiscordMessageConverter(adapter.MessageConverter):
|
|||||||
return platform_message.MessageChain(element_list)
|
return platform_message.MessageChain(element_list)
|
||||||
|
|
||||||
|
|
||||||
class DiscordEventConverter(adapter.EventConverter):
|
class DiscordEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(event: platform_events.Event) -> discord.Message:
|
async def yiri2target(event: platform_events.Event) -> discord.Message:
|
||||||
pass
|
pass
|
||||||
@@ -215,29 +154,21 @@ class DiscordEventConverter(adapter.EventConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class DiscordAdapter(adapter.MessagePlatformAdapter):
|
class DiscordAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
bot: discord.Client
|
bot: discord.Client = pydantic.Field(exclude=True)
|
||||||
|
|
||||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
|
||||||
|
|
||||||
config: dict
|
|
||||||
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
message_converter: DiscordMessageConverter = DiscordMessageConverter()
|
message_converter: DiscordMessageConverter = DiscordMessageConverter()
|
||||||
event_converter: DiscordEventConverter = DiscordEventConverter()
|
event_converter: DiscordEventConverter = DiscordEventConverter()
|
||||||
|
|
||||||
listeners: typing.Dict[
|
listeners: typing.Dict[
|
||||||
typing.Type[platform_events.Event],
|
typing.Type[platform_events.Event],
|
||||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||||
self.config = config
|
bot_account_id = config['client_id']
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
|
||||||
|
|
||||||
self.bot_account_id = self.config['client_id']
|
listeners = {}
|
||||||
|
|
||||||
adapter_self = self
|
adapter_self = self
|
||||||
|
|
||||||
@@ -257,30 +188,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
|||||||
if os.getenv('http_proxy'):
|
if os.getenv('http_proxy'):
|
||||||
args['proxy'] = os.getenv('http_proxy')
|
args['proxy'] = os.getenv('http_proxy')
|
||||||
|
|
||||||
self.bot = MyClient(intents=intents, **args)
|
bot = MyClient(intents=intents, **args)
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
config=config,
|
||||||
|
logger=logger,
|
||||||
|
bot_account_id=bot_account_id,
|
||||||
|
listeners=listeners,
|
||||||
|
bot=bot,
|
||||||
|
)
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
msg_to_send, image_files = await self.message_converter.yiri2target(message)
|
pass
|
||||||
|
|
||||||
try:
|
|
||||||
# 获取频道对象
|
|
||||||
channel = self.bot.get_channel(int(target_id))
|
|
||||||
if channel is None:
|
|
||||||
# 如果本地缓存中没有,尝试从API获取
|
|
||||||
channel = await self.bot.fetch_channel(int(target_id))
|
|
||||||
|
|
||||||
args = {
|
|
||||||
'content': msg_to_send,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(image_files) > 0:
|
|
||||||
args['files'] = image_files
|
|
||||||
|
|
||||||
await channel.send(**args)
|
|
||||||
|
|
||||||
except Exception as e:
|
|
||||||
await self.logger.error(f"Discord send_message failed: {e}")
|
|
||||||
raise e
|
|
||||||
|
|
||||||
async def reply_message(
|
async def reply_message(
|
||||||
self,
|
self,
|
||||||
@@ -312,14 +231,18 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
self.listeners[event_type] = callback
|
self.listeners[event_type] = callback
|
||||||
|
|
||||||
def unregister_listener(
|
def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
self.listeners.pop(event_type)
|
self.listeners.pop(event_type)
|
||||||
|
|
||||||
|
|||||||
@@ -17,13 +17,13 @@ import aiohttp
|
|||||||
import lark_oapi.ws.exception
|
import lark_oapi.ws.exception
|
||||||
import quart
|
import quart
|
||||||
from lark_oapi.api.im.v1 import *
|
from lark_oapi.api.im.v1 import *
|
||||||
|
import pydantic
|
||||||
|
|
||||||
from .. import adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ...core import app
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ..types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ..types import entities as platform_entities
|
import langbot_plugin.api.definition.abstract.platform.event_logger as abstract_platform_logger
|
||||||
from ..logger import EventLogger
|
|
||||||
|
|
||||||
|
|
||||||
class AESCipher(object):
|
class AESCipher(object):
|
||||||
@@ -52,7 +52,7 @@ class AESCipher(object):
|
|||||||
return self.decrypt(enc).decode('utf8')
|
return self.decrypt(enc).decode('utf8')
|
||||||
|
|
||||||
|
|
||||||
class LarkMessageConverter(adapter.MessageConverter):
|
class LarkMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(
|
async def yiri2target(
|
||||||
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
|
message_chain: platform_message.MessageChain, api_client: lark_oapi.Client
|
||||||
@@ -276,7 +276,7 @@ class LarkMessageConverter(adapter.MessageConverter):
|
|||||||
return platform_message.MessageChain(lb_msg_list)
|
return platform_message.MessageChain(lb_msg_list)
|
||||||
|
|
||||||
|
|
||||||
class LarkEventConverter(adapter.EventConverter):
|
class LarkEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
async def yiri2target(
|
async def yiri2target(
|
||||||
event: platform_events.MessageEvent,
|
event: platform_events.MessageEvent,
|
||||||
@@ -320,39 +320,28 @@ class LarkEventConverter(adapter.EventConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class LarkAdapter(adapter.MessagePlatformAdapter):
|
class LarkAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
bot: lark_oapi.ws.Client
|
bot: lark_oapi.ws.Client = pydantic.Field(exclude=True)
|
||||||
api_client: lark_oapi.Client
|
api_client: lark_oapi.Client = pydantic.Field(exclude=True)
|
||||||
|
|
||||||
bot_account_id: str # 用于在流水线中识别at是否是本bot,直接以bot_name作为标识
|
|
||||||
lark_tenant_key: str # 飞书企业key
|
|
||||||
|
|
||||||
message_converter: LarkMessageConverter = LarkMessageConverter()
|
message_converter: LarkMessageConverter = LarkMessageConverter()
|
||||||
event_converter: LarkEventConverter = LarkEventConverter()
|
event_converter: LarkEventConverter = LarkEventConverter()
|
||||||
|
|
||||||
listeners: typing.Dict[
|
listeners: typing.Dict[
|
||||||
typing.Type[platform_events.Event],
|
typing.Type[platform_events.Event],
|
||||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||||
]
|
]
|
||||||
|
|
||||||
config: dict
|
quart_app: quart.Quart = pydantic.Field(exclude=True)
|
||||||
quart_app: quart.Quart
|
|
||||||
ap: app.Application
|
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||||
self.config = config
|
quart_app = quart.Quart(__name__)
|
||||||
self.ap = ap
|
|
||||||
self.logger = logger
|
|
||||||
self.quart_app = quart.Quart(__name__)
|
|
||||||
self.listeners = {}
|
|
||||||
|
|
||||||
@self.quart_app.route('/lark/callback', methods=['POST'])
|
@quart_app.route('/lark/callback', methods=['POST'])
|
||||||
async def lark_callback():
|
async def lark_callback():
|
||||||
try:
|
try:
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
|
|
||||||
self.ap.logger.debug(f'Lark callback event: {data}')
|
|
||||||
|
|
||||||
if 'encrypt' in data:
|
if 'encrypt' in data:
|
||||||
cipher = AESCipher(self.config['encrypt-key'])
|
cipher = AESCipher(self.config['encrypt-key'])
|
||||||
data = cipher.decrypt_string(data['encrypt'])
|
data = cipher.decrypt_string(data['encrypt'])
|
||||||
@@ -378,15 +367,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
if 'im.message.receive_v1' == type:
|
if 'im.message.receive_v1' == type:
|
||||||
try:
|
try:
|
||||||
event = await self.event_converter.target2yiri(p2v1, self.api_client)
|
event = await self.event_converter.target2yiri(p2v1, self.api_client)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
|
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
|
||||||
|
|
||||||
if event.__class__ in self.listeners:
|
if event.__class__ in self.listeners:
|
||||||
await self.listeners[event.__class__](event, self)
|
await self.listeners[event.__class__](event, self)
|
||||||
|
|
||||||
return {'code': 200, 'message': 'ok'}
|
return {'code': 200, 'message': 'ok'}
|
||||||
except Exception as e:
|
except Exception:
|
||||||
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
|
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
|
||||||
return {'code': 500, 'message': 'error'}
|
return {'code': 500, 'message': 'error'}
|
||||||
|
|
||||||
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
|
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
|
||||||
@@ -401,10 +390,20 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
|
lark_oapi.EventDispatcherHandler.builder('', '').register_p2_im_message_receive_v1(sync_on_message).build()
|
||||||
)
|
)
|
||||||
|
|
||||||
self.bot_account_id = config['bot_name']
|
bot_account_id = config['bot_name']
|
||||||
|
|
||||||
self.bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
|
bot = lark_oapi.ws.Client(config['app_id'], config['app_secret'], event_handler=event_handler)
|
||||||
self.api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
|
api_client = lark_oapi.Client.builder().app_id(config['app_id']).app_secret(config['app_secret']).build()
|
||||||
|
|
||||||
|
super().__init__(
|
||||||
|
config=config,
|
||||||
|
logger=logger,
|
||||||
|
listeners={},
|
||||||
|
quart_app=quart_app,
|
||||||
|
bot=bot,
|
||||||
|
api_client=api_client,
|
||||||
|
bot_account_id=bot_account_id,
|
||||||
|
)
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
pass
|
pass
|
||||||
@@ -453,14 +452,18 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
self.listeners[event_type] = callback
|
self.listeners[event_type] = callback
|
||||||
|
|
||||||
def unregister_listener(
|
def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
self.listeners.pop(event_type)
|
self.listeners.pop(event_type)
|
||||||
|
|
||||||
|
|||||||
|
Before Width: | Height: | Size: 25 KiB After Width: | Height: | Size: 25 KiB |
@@ -11,19 +11,19 @@ import threading
|
|||||||
import quart
|
import quart
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
|
||||||
from .. import adapter
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ...core import app
|
from ....core import app
|
||||||
from ..types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ..types import entities as platform_entities
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ...utils import image
|
from ....utils import image
|
||||||
import xml.etree.ElementTree as ET
|
import xml.etree.ElementTree as ET
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from ..logger import EventLogger
|
from ...logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class GewechatMessageConverter(adapter.MessageConverter):
|
class GewechatMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
@@ -398,7 +398,7 @@ class GewechatMessageConverter(adapter.MessageConverter):
|
|||||||
return from_user_name.endswith('@chatroom')
|
return from_user_name.endswith('@chatroom')
|
||||||
|
|
||||||
|
|
||||||
class GewechatEventConverter(adapter.EventConverter):
|
class GewechatEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
def __init__(self, config: dict):
|
def __init__(self, config: dict):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.message_converter = GewechatMessageConverter(config)
|
self.message_converter = GewechatMessageConverter(config)
|
||||||
@@ -458,7 +458,7 @@ class GewechatEventConverter(adapter.EventConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
class GeWeChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
name: str = 'gewechat' # 定义适配器名称
|
name: str = 'gewechat' # 定义适配器名称
|
||||||
|
|
||||||
bot: gewechat_client.GewechatClient
|
bot: gewechat_client.GewechatClient
|
||||||
@@ -475,7 +475,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
listeners: typing.Dict[
|
listeners: typing.Dict[
|
||||||
typing.Type[platform_events.Event],
|
typing.Type[platform_events.Event],
|
||||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
typing.Callable[[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None],
|
||||||
] = {}
|
] = {}
|
||||||
|
|
||||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||||
@@ -491,7 +491,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
async def gewechat_callback():
|
async def gewechat_callback():
|
||||||
data = await quart.request.json
|
data = await quart.request.json
|
||||||
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
# print(json.dumps(data, indent=4, ensure_ascii=False))
|
||||||
self.ap.logger.debug(f'Gewechat callback event: {data}')
|
await self.logger.debug(f'Gewechat callback event: {data}')
|
||||||
|
|
||||||
if 'data' in data:
|
if 'data' in data:
|
||||||
data['Data'] = data['data']
|
data['Data'] = data['data']
|
||||||
@@ -601,7 +601,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
if handler := handler_map.get(msg['type']):
|
if handler := handler_map.get(msg['type']):
|
||||||
handler(msg)
|
handler(msg)
|
||||||
else:
|
else:
|
||||||
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
await self.logger.warning(f'未处理的消息类型: {msg["type"]}')
|
||||||
continue
|
continue
|
||||||
|
|
||||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||||
@@ -625,14 +625,18 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
self.listeners[event_type] = callback
|
self.listeners[event_type] = callback
|
||||||
|
|
||||||
def unregister_listener(
|
def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@@ -656,9 +660,7 @@ class GeWeChatAdapter(adapter.MessagePlatformAdapter):
|
|||||||
|
|
||||||
self.config['app_id'] = app_id
|
self.config['app_id'] = app_id
|
||||||
|
|
||||||
self.ap.logger.info(f'Gewechat 登录成功,app_id: {app_id}')
|
print(f'Gewechat 登录成功,app_id: {app_id}')
|
||||||
|
|
||||||
self.ap.platform_mgr.write_back_config('gewechat', self, self.config)
|
|
||||||
|
|
||||||
# 获取 nickname
|
# 获取 nickname
|
||||||
profile = self.bot.get_profile(self.config['app_id'])
|
profile = self.bot.get_profile(self.config['app_id'])
|
||||||
|
Before Width: | Height: | Size: 274 KiB After Width: | Height: | Size: 274 KiB |
@@ -9,15 +9,15 @@ import traceback
|
|||||||
import nakuru
|
import nakuru
|
||||||
import nakuru.entities.components as nkc
|
import nakuru.entities.components as nkc
|
||||||
|
|
||||||
from .. import adapter as adapter_model
|
from ....pipeline.longtext.strategies import forward
|
||||||
from ...pipeline.longtext.strategies import forward
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ...platform.types import entities as platform_entities
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ...platform.types import events as platform_events
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ..logger import EventLogger
|
from ...logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
class NakuruProjectMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
"""消息转换器"""
|
"""消息转换器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
|||||||
content=content_list,
|
content=content_list,
|
||||||
)
|
)
|
||||||
nakuru_forward_node_list.append(nakuru_forward_node)
|
nakuru_forward_node_list.append(nakuru_forward_node)
|
||||||
except Exception as e:
|
except Exception:
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
|
||||||
nakuru_msg_list.append(nakuru_forward_node_list)
|
nakuru_msg_list.append(nakuru_forward_node_list)
|
||||||
@@ -108,7 +109,7 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
|
|||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
|
||||||
class NakuruProjectEventConverter(adapter_model.EventConverter):
|
class NakuruProjectEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
"""事件转换器"""
|
"""事件转换器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -163,7 +164,7 @@ class NakuruProjectEventConverter(adapter_model.EventConverter):
|
|||||||
raise Exception('未支持转换的事件类型: ' + str(event))
|
raise Exception('未支持转换的事件类型: ' + str(event))
|
||||||
|
|
||||||
|
|
||||||
class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
class NakuruAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
"""nakuru-project适配器"""
|
"""nakuru-project适配器"""
|
||||||
|
|
||||||
bot: nakuru.CQHTTP
|
bot: nakuru.CQHTTP
|
||||||
@@ -255,13 +256,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
source_cls = NakuruProjectEventConverter.yiri2target(event_type)
|
||||||
|
|
||||||
# 包装函数
|
# 包装函数
|
||||||
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls):
|
async def listener_wrapper(app: nakuru.CQHTTP, source: source_cls): # type: ignore
|
||||||
await callback(self.event_converter.target2yiri(source), self)
|
await callback(self.event_converter.target2yiri(source), self)
|
||||||
|
|
||||||
# 将包装函数和原函数的对应关系存入列表
|
# 将包装函数和原函数的对应关系存入列表
|
||||||
@@ -276,13 +279,15 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
# 注册监听器
|
# 注册监听器
|
||||||
self.bot.receiver(source_cls.__name__)(listener_wrapper)
|
self.bot.receiver(source_cls.__name__)(listener_wrapper)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}")
|
self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}')
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def unregister_listener(
|
def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
nakuru_event_name = self.event_converter.yiri2target(event_type).__name__
|
||||||
|
|
||||||
@@ -321,7 +326,6 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
except Exception:
|
except Exception:
|
||||||
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
|
raise Exception('获取go-cqhttp账号信息失败, 请检查是否已启动go-cqhttp并配置正确')
|
||||||
await self.bot._run()
|
await self.bot._run()
|
||||||
self.ap.logger.info('运行 Nakuru 适配器')
|
|
||||||
while True:
|
while True:
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
|
|
||||||
@@ -10,14 +10,14 @@ import botpy
|
|||||||
import botpy.message as botpy_message
|
import botpy.message as botpy_message
|
||||||
import botpy.types.message as botpy_message_type
|
import botpy.types.message as botpy_message_type
|
||||||
|
|
||||||
from .. import adapter as adapter_model
|
import langbot_plugin.api.definition.abstract.platform.adapter as abstract_platform_adapter
|
||||||
from ...pipeline.longtext.strategies import forward
|
from ....pipeline.longtext.strategies import forward
|
||||||
from ...core import app
|
from ....core import app
|
||||||
from ...config import manager as cfg_mgr
|
from ....config import manager as cfg_mgr
|
||||||
from ...platform.types import entities as platform_entities
|
import langbot_plugin.api.entities.builtin.platform.entities as platform_entities
|
||||||
from ...platform.types import events as platform_events
|
import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||||
from ...platform.types import message as platform_message
|
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||||
from ..logger import EventLogger
|
from ...logger import EventLogger
|
||||||
|
|
||||||
|
|
||||||
class OfficialGroupMessage(platform_events.GroupMessage):
|
class OfficialGroupMessage(platform_events.GroupMessage):
|
||||||
@@ -133,7 +133,7 @@ class OpenIDMapping(typing.Generic[K, V]):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
class OfficialMessageConverter(adapter_model.MessageConverter):
|
class OfficialMessageConverter(abstract_platform_adapter.AbstractMessageConverter):
|
||||||
"""QQ 官方消息转换器"""
|
"""QQ 官方消息转换器"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -237,7 +237,7 @@ class OfficialMessageConverter(adapter_model.MessageConverter):
|
|||||||
return chain
|
return chain
|
||||||
|
|
||||||
|
|
||||||
class OfficialEventConverter(adapter_model.EventConverter):
|
class OfficialEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||||
"""事件转换器"""
|
"""事件转换器"""
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
@@ -333,7 +333,7 @@ class OfficialEventConverter(adapter_model.EventConverter):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
class OfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||||
"""QQ 官方消息适配器"""
|
"""QQ 官方消息适配器"""
|
||||||
|
|
||||||
bot: botpy.Client = None
|
bot: botpy.Client = None
|
||||||
@@ -484,7 +484,9 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
def register_listener(
|
def register_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
@@ -501,13 +503,15 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
for event_handler in event_handler_mapping[event_type]:
|
for event_handler in event_handler_mapping[event_type]:
|
||||||
setattr(self.bot, event_handler, wrapper)
|
setattr(self.bot, event_handler, wrapper)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}")
|
self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}')
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def unregister_listener(
|
def unregister_listener(
|
||||||
self,
|
self,
|
||||||
event_type: typing.Type[platform_events.Event],
|
event_type: typing.Type[platform_events.Event],
|
||||||
callback: typing.Callable[[platform_events.Event, adapter_model.MessagePlatformAdapter], None],
|
callback: typing.Callable[
|
||||||
|
[platform_events.Event, abstract_platform_adapter.AbstractMessagePlatformAdapter], None
|
||||||
|
],
|
||||||
):
|
):
|
||||||
delattr(self.bot, event_handler_mapping[event_type])
|
delattr(self.bot, event_handler_mapping[event_type])
|
||||||
|
|
||||||
@@ -519,7 +523,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
|||||||
|
|
||||||
self.cfg['ret_coro'] = True
|
self.cfg['ret_coro'] = True
|
||||||
|
|
||||||
self.ap.logger.info('运行 QQ 官方适配器')
|
await self.logger.info('运行 QQ 官方适配器')
|
||||||
await (await self.bot.start(**self.cfg))
|
await (await self.bot.start(**self.cfg))
|
||||||
|
|
||||||
async def kill(self) -> bool:
|
async def kill(self) -> bool:
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user