feat: switch tool entities and format

This commit is contained in:
Junyan Qin
2025-06-15 12:51:51 +08:00
parent c5eeab2fd0
commit 0c2560cafb
55 changed files with 455 additions and 774 deletions

View File

@@ -104,7 +104,7 @@ class QQOfficialClient:
return {'code': 0, 'message': 'success'}
except Exception as e:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
return {'error': str(e)}, 400
async def run_task(self, host: str, port: int, *args, **kwargs):
@@ -168,7 +168,6 @@ class QQOfficialClient:
if not await self.check_access_token():
await self.get_access_token()
url = self.base_url + '/v2/users/' + user_openid + '/messages'
async with httpx.AsyncClient() as client:
headers = {
@@ -193,7 +192,6 @@ class QQOfficialClient:
if not await self.check_access_token():
await self.get_access_token()
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
async with httpx.AsyncClient() as client:
headers = {
@@ -209,7 +207,7 @@ class QQOfficialClient:
if response.status_code == 200:
return
else:
await self.logger.error(f"发送群聊消息失败:{response.json()}")
await self.logger.error(f'发送群聊消息失败:{response.json()}')
raise Exception(response.read().decode())
async def send_channle_group_text_msg(self, channel_id: str, content: str, msg_id: str):
@@ -217,7 +215,6 @@ class QQOfficialClient:
if not await self.check_access_token():
await self.get_access_token()
url = self.base_url + '/channels/' + channel_id + '/messages'
async with httpx.AsyncClient() as client:
headers = {
@@ -240,7 +237,6 @@ class QQOfficialClient:
"""发送频道私聊消息"""
if not await self.check_access_token():
await self.get_access_token()
url = self.base_url + '/dms/' + guild_id + '/messages'
async with httpx.AsyncClient() as client:

View File

@@ -34,7 +34,6 @@ class SlackClient:
if self.bot_user_id and bot_user_id == self.bot_user_id:
return jsonify({'status': 'ok'})
# 处理私信
if data and data.get('event', {}).get('channel_type') in ['im']:
@@ -52,7 +51,7 @@ class SlackClient:
return jsonify({'status': 'ok'})
except Exception as e:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
raise (e)
async def _handle_message(self, event: SlackEvent):
@@ -82,7 +81,7 @@ class SlackClient:
self.bot_user_id = response['message']['bot_id']
return
except Exception as e:
await self.logger.error(f"Error in send_message: {e}")
await self.logger.error(f'Error in send_message: {e}')
raise e
async def send_message_to_one(self, text: str, user_id: str):
@@ -93,7 +92,7 @@ class SlackClient:
return
except Exception as e:
await self.logger.error(f"Error in send_message: {traceback.format_exc()}")
await self.logger.error(f'Error in send_message: {traceback.format_exc()}')
raise e
async def run_task(self, host: str, port: int, *args, **kwargs):

View File

@@ -1 +1,4 @@
from .client import WeChatPadClient
from .client import WeChatPadClient
__all__ = ['WeChatPadClient']

View File

@@ -1,4 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
class ChatRoomApi:
@@ -7,8 +7,6 @@ class ChatRoomApi:
self.token = token
def get_chatroom_member_detail(self, chatroom_name):
params = {
"ChatRoomName": chatroom_name
}
params = {'ChatRoomName': chatroom_name}
url = self.base_url + '/group/GetChatroomMemberDetail'
return post_json(url, token=self.token, data=params)

View File

@@ -1,32 +1,23 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
import httpx
import base64
class DownloadApi:
def __init__(self, base_url, token):
self.base_url = base_url
self.token = token
def send_download(self, aeskey, file_type, file_url):
json_data = {
"AesKey": aeskey,
"FileType": file_type,
"FileURL": file_url
}
url = self.base_url + "/message/SendCdnDownload"
json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
url = self.base_url + '/message/SendCdnDownload'
return post_json(url, token=self.token, data=json_data)
def get_msg_voice(self,buf_id, length, new_msgid):
json_data = {
"Bufid": buf_id,
"Length": length,
"NewMsgId": new_msgid,
"ToUserName": ""
}
url = self.base_url + "/message/GetMsgVoice"
def get_msg_voice(self, buf_id, length, new_msgid):
json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
url = self.base_url + '/message/GetMsgVoice'
return post_json(url, token=self.token, data=json_data)
async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
@@ -36,4 +27,4 @@ class DownloadApi:
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
return base64_str
else:
raise Exception('获取文件失败')
raise Exception('获取文件失败')

View File

@@ -1,11 +1,6 @@
from libs.wechatpad_api.util.http_util import post_json,async_request
from typing import List, Dict, Any, Optional
class FriendApi:
"""联系人API类处理所有与联系人相关的操作"""
def __init__(self, base_url: str, token: str):
self.base_url = base_url
self.token = token

View File

@@ -1,37 +1,34 @@
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json
from libs.wechatpad_api.util.http_util import post_json, get_json
class LoginApi:
def __init__(self, base_url: str, token: str = None, admin_key: str = None):
'''
"""
Args:
base_url: 原始路径
token: token
admin_key: 管理员key
'''
"""
self.base_url = base_url
self.token = token
# self.admin_key = admin_key
def get_token(self, admin_key, day: int=365):
def get_token(self, admin_key, day: int = 365):
# 获取普通token
url = f"{self.base_url}/admin/GenAuthKey1"
json_data = {
"Count": 1,
"Days": day
}
url = f'{self.base_url}/admin/GenAuthKey1'
json_data = {'Count': 1, 'Days': day}
return post_json(base_url=url, token=admin_key, data=json_data)
def get_login_qr(self, Proxy: str = ""):
'''
def get_login_qr(self, Proxy: str = ''):
"""
Args:
Proxy:异地使用时代理
Returns:json数据
'''
"""
"""
{
@@ -49,54 +46,37 @@ class LoginApi:
}
"""
#获取登录二维码
url = f"{self.base_url}/login/GetLoginQrCodeNew"
# 获取登录二维码
url = f'{self.base_url}/login/GetLoginQrCodeNew'
check = False
if Proxy != "":
if Proxy != '':
check = True
json_data = {
"Check": check,
"Proxy": Proxy
}
json_data = {'Check': check, 'Proxy': Proxy}
return post_json(base_url=url, token=self.token, data=json_data)
def get_login_status(self):
# 获取登录状态
url = f'{self.base_url}/login/GetLoginStatus'
return get_json(base_url=url, token=self.token)
def logout(self):
# 退出登录
url = f'{self.base_url}/login/LogOut'
return post_json(base_url=url, token=self.token)
def wake_up_login(self, Proxy: str = ""):
def wake_up_login(self, Proxy: str = ''):
# 唤醒登录
url = f'{self.base_url}/login/WakeUpLogin'
check = False
if Proxy != "":
if Proxy != '':
check = True
json_data = {
"Check": check,
"Proxy": ""
}
json_data = {'Check': check, 'Proxy': ''}
return post_json(base_url=url, token=self.token, data=json_data)
def login(self,admin_key):
def login(self, admin_key):
login_status = self.get_login_status()
if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信":
print("token已经失效重新获取")
if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
print('token已经失效重新获取')
token_data = self.get_token(admin_key)
self.token = token_data["Data"][0]
self.token = token_data['Data'][0]

View File

@@ -1,5 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json
from libs.wechatpad_api.util.http_util import post_json
class MessageApi:
@@ -7,8 +6,8 @@ class MessageApi:
self.base_url = base_url
self.token = token
def post_text(self, to_wxid, content, ats: list= []):
'''
def post_text(self, to_wxid, content, ats: list = []):
"""
Args:
app_id: 微信id
@@ -18,106 +17,64 @@ class MessageApi:
Returns:
'''
url = self.base_url + "/message/SendTextMessage"
"""
url = self.base_url + '/message/SendTextMessage'
"""发送文字消息"""
json_data = {
"MsgItem": [
{
"AtWxIDList": ats,
"ImageContent": "",
"MsgType": 0,
"TextContent": content,
"ToUserName": to_wxid
}
]
}
return post_json(base_url=url, token=self.token, data=json_data)
'MsgItem': [
{'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
]
}
return post_json(base_url=url, token=self.token, data=json_data)
def post_image(self, to_wxid, img_url, ats: list= []):
def post_image(self, to_wxid, img_url, ats: list = []):
"""发送图片消息"""
# 这里好像可以尝试发送多个暂时未测试
json_data = {
"MsgItem": [
{
"AtWxIDList": ats,
"ImageContent": img_url,
"MsgType": 0,
"TextContent": '',
"ToUserName": to_wxid
}
'MsgItem': [
{'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
]
}
url = self.base_url + "/message/SendImageMessage"
url = self.base_url + '/message/SendImageMessage'
return post_json(base_url=url, token=self.token, data=json_data)
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
"""发送语音消息"""
json_data = {
"ToUserName": to_wxid,
"VoiceData": voice_data,
"VoiceFormat": voice_forma,
"VoiceSecond": voice_duration
'ToUserName': to_wxid,
'VoiceData': voice_data,
'VoiceFormat': voice_forma,
'VoiceSecond': voice_duration,
}
url = self.base_url + "/message/SendVoice"
url = self.base_url + '/message/SendVoice'
return post_json(base_url=url, token=self.token, data=json_data)
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
"""发送名片消息"""
param = {
"CardAlias": alias,
"CardFlag": flag,
"CardNickName": nick_name,
"CardWxId": name_card_wxid,
"ToUserName": to_wxid
'CardAlias': alias,
'CardFlag': flag,
'CardNickName': nick_name,
'CardWxId': name_card_wxid,
'ToUserName': to_wxid,
}
url = f"{self.base_url}/message/ShareCardMessage"
url = f'{self.base_url}/message/ShareCardMessage'
return post_json(base_url=url, token=self.token, data=param)
def post_emoji(self, to_wxid, emoji_md5, emoji_size:int=0):
def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
"""发送emoji消息"""
json_data = {
"EmojiList": [
{
"EmojiMd5": emoji_md5,
"EmojiSize": emoji_size,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendEmojiMessage"
json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
url = f'{self.base_url}/message/SendEmojiMessage'
return post_json(base_url=url, token=self.token, data=json_data)
def post_app_msg(self, to_wxid,xml_data, contenttype:int=0):
def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
"""发送appmsg消息"""
json_data = {
"AppList": [
{
"ContentType": contenttype,
"ContentXML": xml_data,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendAppMessage"
json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
url = f'{self.base_url}/message/SendAppMessage'
return post_json(base_url=url, token=self.token, data=json_data)
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
"""撤回消息"""
param = {
"ClientMsgId": msg_id,
"CreateTime": create_time,
"NewMsgId": new_msg_id,
"ToUserName": to_wxid
}
url = f"{self.base_url}/message/RevokeMsg"
return post_json(base_url=url, token=self.token, data=param)
param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
url = f'{self.base_url}/message/RevokeMsg'
return post_json(base_url=url, token=self.token, data=param)

View File

@@ -12,12 +12,9 @@ class UserApi:
return get_json(base_url=url, token=self.token)
def get_qr_code(self, recover:bool=True, style:int=8):
def get_qr_code(self, recover: bool = True, style: int = 8):
"""获取自己的二维码"""
param = {
"Recover": recover,
"Style": style
}
param = {'Recover': recover, 'Style': style}
url = f'{self.base_url}/user/GetMyQRCode'
return post_json(base_url=url, token=self.token, data=param)
@@ -26,12 +23,8 @@ class UserApi:
url = f'{self.base_url}/equipment/GetSafetyInfo'
return post_json(base_url=url, token=self.token)
async def update_head_img(self, head_img_base64):
async def update_head_img(self, head_img_base64):
"""修改头像"""
param = {
"Base64": head_img_base64
}
param = {'Base64': head_img_base64}
url = f'{self.base_url}/user/UploadHeadImage'
return await async_request(base_url=url, token_key=self.token, json=param)
return await async_request(base_url=url, token_key=self.token, json=param)

View File

@@ -1,4 +1,3 @@
from libs.wechatpad_api.api.login import LoginApi
from libs.wechatpad_api.api.friend import FriendApi
from libs.wechatpad_api.api.message import MessageApi
@@ -7,9 +6,6 @@ from libs.wechatpad_api.api.downloadpai import DownloadApi
from libs.wechatpad_api.api.chatroom import ChatRoomApi
class WeChatPadClient:
def __init__(self, base_url, token, logger=None):
self._login_api = LoginApi(base_url, token)
@@ -20,16 +16,16 @@ class WeChatPadClient:
self._chatroom_api = ChatRoomApi(base_url, token)
self.logger = logger
def get_token(self,admin_key, day: int):
'''获取token'''
def get_token(self, admin_key, day: int):
"""获取token"""
return self._login_api.get_token(admin_key, day)
def get_login_qr(self, Proxy:str=""):
def get_login_qr(self, Proxy: str = ''):
"""登录二维码"""
return self._login_api.get_login_qr(Proxy=Proxy)
def awaken_login(self, Proxy:str=""):
'''唤醒登录'''
def awaken_login(self, Proxy: str = ''):
"""唤醒登录"""
return self._login_api.wake_up_login(Proxy=Proxy)
def log_out(self):
@@ -40,59 +36,57 @@ class WeChatPadClient:
"""获取登录状态"""
return self._login_api.get_login_status()
def send_text_message(self, to_wxid, message, ats: list=[]):
def send_text_message(self, to_wxid, message, ats: list = []):
"""发送文本消息"""
return self._message_api.post_text(to_wxid, message, ats)
return self._message_api.post_text(to_wxid, message, ats)
def send_image_message(self, to_wxid, img_url, ats: list=[]):
def send_image_message(self, to_wxid, img_url, ats: list = []):
"""发送图片消息"""
return self._message_api.post_image(to_wxid, img_url, ats)
return self._message_api.post_image(to_wxid, img_url, ats)
def send_voice_message(self, to_wxid, voice_data, voice_forma, voice_duration):
"""发送音频消息"""
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
return self._message_api.post_voice(to_wxid, voice_data, voice_forma, voice_duration)
def send_app_message(self, to_wxid, app_message, type):
"""发送app消息"""
return self._message_api.post_app_msg(to_wxid, app_message, type)
return self._message_api.post_app_msg(to_wxid, app_message, type)
def send_emoji_message(self, to_wxid, emoji_md5, emoji_size):
"""发送emoji消息"""
return self._message_api.post_emoji(to_wxid,emoji_md5,emoji_size)
return self._message_api.post_emoji(to_wxid, emoji_md5, emoji_size)
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
"""撤回消息"""
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
return self._message_api.revoke_msg(to_wxid, msg_id, new_msg_id, create_time)
def get_profile(self):
"""获取用户信息"""
return self._user_api.get_profile()
def get_qr_code(self, recover:bool=True, style:int=8):
def get_qr_code(self, recover: bool = True, style: int = 8):
"""获取用户二维码"""
return self._user_api.get_qr_code(recover=recover, style=style)
return self._user_api.get_qr_code(recover=recover, style=style)
def get_safety_info(self):
"""获取设备信息"""
return self._user_api.get_safety_info()
return self._user_api.get_safety_info()
def update_head_img(self, head_img_base64):
def update_head_img(self, head_img_base64):
"""上传用户头像"""
return self._user_api.update_head_img(head_img_base64)
return self._user_api.update_head_img(head_img_base64)
def cdn_download(self, aeskey, file_type, file_url):
"""cdn下载"""
return self._download_api.send_download( aeskey, file_type, file_url)
return self._download_api.send_download(aeskey, file_type, file_url)
def get_msg_voice(self,buf_id, length, msgid):
def get_msg_voice(self, buf_id, length, msgid):
"""下载语音"""
return self._download_api.get_msg_voice(buf_id, length, msgid)
async def download_base64(self,url):
async def download_base64(self, url):
return await self._download_api.download_url_to_base64(download_url=url)
def get_chatroom_member_detail(self, chatroom_name):
"""查看群成员详情"""
return self._chatroom_api.get_chatroom_member_detail(chatroom_name)

View File

@@ -1,10 +1,9 @@
import requests
import aiohttp
def post_json(base_url, token, data=None):
headers = {
'Content-Type': 'application/json'
}
headers = {'Content-Type': 'application/json'}
url = base_url + f'?key={token}'
@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
else:
raise RuntimeError(response.text)
except Exception as e:
print(f"http请求失败, url={url}, exception={e}")
print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e))
def get_json(base_url, token):
headers = {
'Content-Type': 'application/json'
}
def get_json(base_url, token):
headers = {'Content-Type': 'application/json'}
url = base_url + f'?key={token}'
@@ -39,21 +36,18 @@ def get_json(base_url, token):
else:
raise RuntimeError(response.text)
except Exception as e:
print(f"http请求失败, url={url}, exception={e}")
print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e))
import aiohttp
import asyncio
async def async_request(
base_url: str,
token_key: str,
method: str = 'POST',
params: dict = None,
# headers: dict = None,
data: dict = None,
json: dict = None
base_url: str,
token_key: str,
method: str = 'POST',
params: dict = None,
# headers: dict = None,
data: dict = None,
json: dict = None,
):
"""
通用异步请求函数
@@ -67,18 +61,11 @@ async def async_request(
:param json: JSON数据
:return: 响应文本
"""
headers = {
'Content-Type': 'application/json'
}
url = f"{base_url}?key={token_key}"
headers = {'Content-Type': 'application/json'}
url = f'{base_url}?key={token_key}'
async with aiohttp.ClientSession() as session:
async with session.request(
method=method,
url=url,
params=params,
headers=headers,
data=data,
json=json
method=method, url=url, params=params, headers=headers, data=data, json=json
) as response:
response.raise_for_status() # 如果状态码不是200抛出异常
result = await response.json()
@@ -89,4 +76,3 @@ async def async_request(
# return await result
# else:
# raise RuntimeError("请求失败",response.text)

View File

@@ -1,31 +1,34 @@
import qrcode
def print_green(text):
print(f"\033[32m{text}\033[0m")
print(f'\033[32m{text}\033[0m')
def print_yellow(text):
print(f"\033[33m{text}\033[0m")
print(f'\033[33m{text}\033[0m')
def print_red(text):
print(f"\033[31m{text}\033[0m")
print(f'\033[31m{text}\033[0m')
def make_and_print_qr(url):
"""生成并打印二维码
Args:
url: 需要生成二维码的URL字符串
Returns:
None
功能:
1. 在终端打印二维码的ASCII图形
2. 同时提供在线二维码生成链接作为备选
"""
print_green("请扫描下方二维码登录")
print_green('请扫描下方二维码登录')
qr = qrcode.QRCode()
qr.add_data(url)
qr.make()
qr.print_ascii(invert=True)
print_green(f"也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}")
print_green(f'也可以访问下方链接获取二维码:\nhttps://api.qrserver.com/v1/create-qr-code/?data={url}')

View File

@@ -57,7 +57,7 @@ class WecomClient:
if 'access_token' in data:
return data['access_token']
else:
await self.logger.error(f"获取accesstoken失败:{response.json()}")
await self.logger.error(f'获取accesstoken失败:{response.json()}')
raise Exception(f'未获取access token: {data}')
async def get_users(self):
@@ -129,7 +129,7 @@ class WecomClient:
response = await client.post(url, json=params)
data = response.json()
except Exception as e:
await self.logger.error(f"发送图片失败:{data}")
await self.logger.error(f'发送图片失败:{data}')
raise Exception('Failed to send image: ' + str(e))
# 企业微信错误码40014和42001代表accesstoken问题
@@ -164,7 +164,7 @@ class WecomClient:
self.access_token = await self.get_access_token(self.secret)
return await self.send_private_msg(user_id, agent_id, content)
if data['errcode'] != 0:
await self.logger.error(f"发送消息失败:{data}")
await self.logger.error(f'发送消息失败:{data}')
raise Exception('Failed to send message: ' + str(data))
async def handle_callback_request(self):
@@ -181,7 +181,7 @@ class WecomClient:
echostr = request.args.get('echostr')
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
await self.logger.error("验证失败")
await self.logger.error('验证失败')
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
@@ -189,9 +189,8 @@ class WecomClient:
encrypt_msg = await request.data
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0:
await self.logger.error("消息解密失败")
await self.logger.error('消息解密失败')
raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理
message_data = await self.get_message(xml_msg)
@@ -202,7 +201,7 @@ class WecomClient:
return 'success'
except Exception as e:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
return f'Error processing request: {str(e)}', 400
async def run_task(self, host: str, port: int, *args, **kwargs):
@@ -301,7 +300,7 @@ class WecomClient:
except binascii.Error as e:
raise ValueError(f'Invalid base64 string: {str(e)}')
else:
await self.logger.error("Image对象出错")
await self.logger.error('Image对象出错')
raise ValueError('image对象出错')
# 设置 multipart/form-data 格式的文件
@@ -325,7 +324,7 @@ class WecomClient:
self.access_token = await self.get_access_token(self.secret)
media_id = await self.upload_to_work(image)
if data.get('errcode', 0) != 0:
await self.logger.error(f"上传图片失败:{data}")
await self.logger.error(f'上传图片失败:{data}')
raise Exception('failed to upload file')
media_id = data.get('media_id')

View File

@@ -187,7 +187,7 @@ class WecomCSClient:
self.access_token = await self.get_access_token(self.secret)
return await self.send_text_msg(open_kfid, external_userid, msgid, content)
if data['errcode'] != 0:
await self.logger.error(f"发送消息失败:{data}")
await self.logger.error(f'发送消息失败:{data}')
raise Exception('Failed to send message')
return data
@@ -227,7 +227,7 @@ class WecomCSClient:
return 'success'
except Exception as e:
if self.logger:
await self.logger.error(f"Error in handle_callback_request: {traceback.format_exc()}")
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
else:
traceback.print_exc()
return f'Error processing request: {str(e)}', 400

View File

@@ -7,7 +7,7 @@ from ....core import app
from ....entity.persistence import model as persistence_model
from ....entity.persistence import pipeline as persistence_pipeline
from ....provider.modelmgr import requester as model_requester
from ....provider import entities as llm_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
class ModelsService:
@@ -99,7 +99,7 @@ class ModelsService:
await runtime_llm_model.requester.invoke_llm(
query=None,
model=runtime_llm_model,
messages=[llm_entities.Message(role='user', content='Hello, world!')],
messages=[provider_message.Message(role='user', content='Hello, world!')],
funcs=[],
extra_args={},
)

View File

@@ -5,6 +5,7 @@ import typing
from ..core import app, entities as core_entities
from . import entities, operator, errors
from ..utils import importutil
import langbot_plugin.api.entities.builtin.provider.session as provider_session
# 引入所有算子以便注册
from . import operators
@@ -90,7 +91,7 @@ class CommandManager:
self,
command_text: str,
query: core_entities.Query,
session: core_entities.Session,
session: provider_session.Session,
) -> typing.AsyncGenerator[entities.CommandReturn, None]:
"""执行命令"""

View File

@@ -4,6 +4,7 @@ import typing
import pydantic.v1 as pydantic
import langbot_plugin.api.entities.builtin.provider.session as provider_session
from ..core import entities as core_entities
from . import errors
from ..platform.types import message as platform_message
@@ -37,7 +38,7 @@ class ExecuteContext(pydantic.BaseModel):
query: core_entities.Query
"""本次消息的请求对象"""
session: core_entities.Session
session: provider_session.Session
"""本次消息所属的会话对象"""
command_text: str

View File

@@ -2,17 +2,15 @@ from __future__ import annotations
import enum
import typing
import datetime
import asyncio
import pydantic.v1 as pydantic
from ..provider import entities as llm_entities
from ..provider.modelmgr import requester
from ..provider.tools import entities as tools_entities
from ..platform import adapter as msadapter
from ..platform.types import message as platform_message
from ..platform.types import events as platform_events
import langbot_plugin.api.entities.builtin.provider.session as provider_session
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class LifecycleControlScope(enum.Enum):
@@ -65,7 +63,7 @@ class Query(pydantic.BaseModel):
adapter: msadapter.MessagePlatformAdapter
"""消息平台适配器对象单个app中可能启用了多个消息平台适配器此对象表明发起此query的适配器"""
session: typing.Optional[Session] = None
session: typing.Optional[provider_session.Session] = None
"""会话对象,由前置处理器阶段设置"""
messages: typing.Optional[list[llm_entities.Message]] = []
@@ -80,10 +78,10 @@ class Query(pydantic.BaseModel):
variables: typing.Optional[dict[str, typing.Any]] = None
"""变量由前置处理器阶段设置。在prompt中嵌入或由 Runner 传递到 LLMOps 平台。"""
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
use_llm_model_uuid: typing.Optional[str] = None
"""使用的对话模型,由前置处理器阶段设置"""
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
use_funcs: typing.Optional[list[resource_tool.LLMTool]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: (
@@ -95,7 +93,7 @@ class Query(pydantic.BaseModel):
"""回复消息链从resp_messages包装而得"""
# ======= 内部保留 =======
current_stage: typing.Optional['pkg.pipeline.pipelinemgr.StageInstContainer'] = None
current_stage_name: typing.Optional[str] = None
"""当前所处阶段"""
class Config:
@@ -120,57 +118,3 @@ class Query(pydantic.BaseModel):
if self.variables is None:
return {}
return self.variables
class Conversation(pydantic.BaseModel):
"""对话,包含于 Session 中,一个 Session 可以有多个历史 Conversation但只有一个当前使用的 Conversation"""
prompt: llm_entities.Prompt
messages: list[llm_entities.Message]
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
use_llm_model: typing.Optional[requester.RuntimeLLMModel] = None
use_funcs: typing.Optional[list[tools_entities.LLMFunction]]
pipeline_uuid: str
"""流水线UUID。"""
bot_uuid: str
"""机器人UUID。"""
uuid: typing.Optional[str] = None
"""该对话的 uuid在创建时不会自动生成。而是当使用 Dify API 等由外部管理对话信息的服务时,用于绑定外部的会话。具体如何使用,取决于 Runner。"""
class Config:
arbitrary_types_allowed = True
class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: typing.Union[int, str]
sender_id: typing.Optional[typing.Union[int, str]] = 0
use_prompt_name: typing.Optional[str] = 'default'
using_conversation: typing.Optional[Conversation] = None
conversations: typing.Optional[list[Conversation]] = pydantic.Field(default_factory=list)
create_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
update_time: typing.Optional[datetime.datetime] = pydantic.Field(default_factory=datetime.datetime.now)
semaphore: typing.Optional[asyncio.Semaphore] = None
"""当前会话的信号量,用于限制并发"""
class Config:
arbitrary_types_allowed = True

View File

@@ -5,7 +5,7 @@ from ...core import app
from .. import stage, entities
from ...core import entities as core_entities
from . import filter as filter_model, entities as filter_entities
from ...provider import entities as llm_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...platform.types import message as platform_message
from ...utils import importutil
@@ -142,7 +142,7 @@ class ContentFilterStage(stage.PipelineStage):
return await self._pre_process(str(query.message_chain).strip(), query)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(
if isinstance(query.resp_messages[-1], provider_message.Message) and isinstance(
query.resp_messages[-1].content, str
):
return await self._post_process(query.resp_messages[-1].content, query)

View File

@@ -136,7 +136,7 @@ class RuntimePipeline:
while i < len(self.stage_containers):
stage_container = self.stage_containers[i]
query.current_stage = stage_container # 标记到 Query 对象里
query.current_stage_name = stage_container.inst_name # 标记到 Query 对象里
result = stage_container.inst.process(query, stage_container.inst_name)
@@ -196,7 +196,7 @@ class RuntimePipeline:
await self._execute_from_stage(0, query)
except Exception as e:
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
inst_name = query.current_stage_name if query.current_stage_name else 'unknown'
self.ap.logger.error(f'处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}')
self.ap.logger.error(f'Traceback: {traceback.format_exc()}')
finally:

View File

@@ -4,7 +4,7 @@ import datetime
from .. import stage, entities
from ...core import entities as core_entities
from ...provider import entities as llm_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ...plugin import events
from ...platform.types import message as platform_message
@@ -49,19 +49,20 @@ class PreProcessor(stage.PipelineStage):
query.bot_uuid,
)
conversation.use_llm_model = llm_model
# 设置query
query.session = session
query.prompt = conversation.prompt.copy()
query.messages = conversation.messages.copy()
query.use_llm_model = llm_model
query.use_llm_model_uuid = llm_model.model_entity.uuid
if selected_runner == 'local-agent':
query.use_funcs = (
conversation.use_funcs if query.use_llm_model.model_entity.abilities.__contains__('func_call') else None
)
query.use_funcs = []
if llm_model.model_entity.abilities.__contains__('func_call'):
query.use_funcs = await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
)
query.variables = {
'session_id': f'{query.session.launcher_type.value}_{query.session.launcher_id}',
@@ -73,7 +74,7 @@ class PreProcessor(stage.PipelineStage):
# Check if this model supports vision, if not, remove all images
# TODO this checking should be performed in runner, and in this stage, the image should be reserved
if selected_runner == 'local-agent' and not query.use_llm_model.model_entity.abilities.__contains__('vision'):
if selected_runner == 'local-agent' and not llm_model.model_entity.abilities.__contains__('vision'):
for msg in query.messages:
if isinstance(msg.content, list):
for me in msg.content:
@@ -87,28 +88,24 @@ class PreProcessor(stage.PipelineStage):
for me in query.message_chain:
if isinstance(me, platform_message.Plain):
content_list.append(llm_entities.ContentElement.from_text(me.text))
content_list.append(provider_message.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
'vision'
):
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
if me.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(me.base64))
content_list.append(provider_message.ContentElement.from_image_base64(me.base64))
elif isinstance(me, platform_message.Quote) and qoute_msg:
for msg in me.origin:
if isinstance(msg, platform_message.Plain):
content_list.append(llm_entities.ContentElement.from_text(msg.text))
content_list.append(provider_message.ContentElement.from_text(msg.text))
elif isinstance(msg, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
'vision'
):
if selected_runner != 'local-agent' or llm_model.model_entity.abilities.__contains__('vision'):
if msg.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
content_list.append(provider_message.ContentElement.from_image_base64(msg.base64))
query.variables['user_message_text'] = plain_text
query.user_message = llm_entities.Message(role='user', content=content_list)
query.user_message = provider_message.Message(role='user', content=content_list)
# =========== 触发事件 PromptPreProcessing
event_ctx = await self.ap.plugin_mgr.emit_event(

View File

@@ -5,7 +5,7 @@ import typing
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
from ....plugin import events
from ....platform.types import message as platform_message
@@ -64,7 +64,7 @@ class CommandHandler(handler.MessageHandler):
async for ret in self.ap.cmd_mgr.execute(command_text=command_text, query=query, session=session):
if ret.error is not None:
query.resp_messages.append(
llm_entities.Message(
provider_message.Message(
role='command',
content=str(ret.error),
)
@@ -74,16 +74,16 @@ class CommandHandler(handler.MessageHandler):
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
elif ret.text is not None or ret.image_url is not None:
content: list[llm_entities.ContentElement] = []
content: list[provider_message.ContentElement] = []
if ret.text is not None:
content.append(llm_entities.ContentElement.from_text(ret.text))
content.append(provider_message.ContentElement.from_text(ret.text))
if ret.image_url is not None:
content.append(llm_entities.ContentElement.from_image_url(ret.image_url))
content.append(provider_message.ContentElement.from_image_url(ret.image_url))
query.resp_messages.append(
llm_entities.Message(
provider_message.Message(
role='command',
content=content,
)

View File

@@ -16,7 +16,6 @@ from ..logger import EventLogger
class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain,
@@ -78,8 +77,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
return msg_list, msg_id, msg_time
@staticmethod
async def target2yiri(message: str, message_id: int = -1,bot=None):
print(message)
async def target2yiri(message: str, message_id: int = -1, bot=None):
message = aiocqhttp.Message(message)
def get_face_name(face_id):
@@ -119,30 +117,28 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
return face_code_dict.get(face_id,'')
async def process_message_data(msg_data, reply_list):
if msg_data["type"] == "image":
image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url'])
reply_list.append(
platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
if msg_data['type'] == 'image':
image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url'])
reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
elif msg_data["type"] == "text":
reply_list.append(platform_message.Plain(text=msg_data["data"]["text"]))
elif msg_data['type'] == 'text':
reply_list.append(platform_message.Plain(text=msg_data['data']['text']))
elif msg_data["type"] == "forward": # 这里来应该传入转发消息组暂时传入qoute
for forward_msg_datas in msg_data["data"]["content"]:
for forward_msg_data in forward_msg_datas["message"]:
elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组暂时传入qoute
for forward_msg_datas in msg_data['data']['content']:
for forward_msg_data in forward_msg_datas['message']:
await process_message_data(forward_msg_data, reply_list)
elif msg_data["type"] == "at":
if msg_data["data"]['qq'] == 'all':
elif msg_data['type'] == 'at':
if msg_data['data']['qq'] == 'all':
reply_list.append(platform_message.AtAll())
else:
reply_list.append(
platform_message.At(
target=msg_data["data"]['qq'],
target=msg_data['data']['qq'],
)
)
yiri_msg_list = []
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
@@ -178,14 +174,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
# await process_message_data(msg_data, yiri_msg_list)
pass
elif msg.type == 'reply': # 此处处理引用消息传入Qoute
msg_datas = await bot.get_msg(message_id=msg.data["id"])
msg_datas = await bot.get_msg(message_id=msg.data['id'])
for msg_data in msg_datas["message"]:
for msg_data in msg_datas['message']:
await process_message_data(msg_data, reply_list)
reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list)
reply_msg = platform_message.Quote(
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
)
yiri_msg_list.append(reply_msg)
elif msg.type == 'file':
@@ -210,32 +207,19 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type='dice',face_id=int(face_id),face_name='骰子'))
chain = platform_message.MessageChain(yiri_msg_list)
return chain
class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
return event.source_platform_object
@staticmethod
async def target2yiri(event: aiocqhttp.Event,bot=None):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id,bot)
async def target2yiri(event: aiocqhttp.Event, bot=None):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
if event.message_type == 'group':
@@ -345,7 +329,7 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
async def on_message(event: aiocqhttp.Event):
self.bot_account_id = event.self_id
try:
return await callback(await self.event_converter.target2yiri(event,self.bot), self)
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
except Exception:
await self.logger.error(f'Error in on_message: {traceback.format_exc()}')
traceback.print_exc()

View File

@@ -378,15 +378,15 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
if 'im.message.receive_v1' == type:
try:
event = await self.event_converter.target2yiri(p2v1, self.api_client)
except Exception as e:
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self)
return {'code': 200, 'message': 'ok'}
except Exception as e:
await self.logger.error(f"Error in lark callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
return {'code': 500, 'message': 'error'}
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):

View File

@@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
content=content_list,
)
nakuru_forward_node_list.append(nakuru_forward_node)
except Exception as e:
except Exception:
import traceback
traceback.print_exc()
nakuru_msg_list.append(nakuru_forward_node_list)
@@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
# 注册监听器
self.bot.receiver(source_cls.__name__)(listener_wrapper)
except Exception as e:
self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}")
self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}')
raise e
def unregister_listener(

View File

@@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message)

View File

@@ -501,7 +501,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper)
except Exception as e:
self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}")
self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}')
raise e
def unregister_listener(

View File

@@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient(
app_id=config['appid'],
secret=config['secret'],
token=config['token'],
logger=self.logger
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
)
async def reply_message(
@@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = 'justbot'
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message)

View File

@@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
if missing_keys:
raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger)
self.bot = SlackClient(
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
)
async def reply_message(
self,
@@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = 'SlackBot'
try:
return await callback(await self.event_converter.target2yiri(event, self.bot), self)
except Exception as e:
await self.logger.error(f"Error in slack callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in slack callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('im')(on_message)

View File

@@ -160,8 +160,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
try:
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
await self.listeners[type(lb_event)](lb_event, self)
except Exception as e:
await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
self.application = ApplicationBuilder().token(self.config['token']).build()
self.bot = self.application.bot

View File

@@ -1,5 +1,4 @@
import requests
import websockets
import websocket
import json
import time
@@ -10,53 +9,42 @@ from libs.wechatpad_api.client import WeChatPadClient
import typing
import asyncio
import traceback
import time
import re
import base64
import uuid
import json
import os
import copy
import datetime
import threading
import quart
import aiohttp
from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...utils import image
from ..logger import EventLogger
import xml.etree.ElementTree as ET
from typing import Optional, List, Tuple
from typing import Optional, Tuple
from functools import partial
import logging
class WeChatPadMessageConverter(adapter.MessageConverter):
class WeChatPadMessageConverter(adapter.MessageConverter):
def __init__(self, config: dict):
self.config = config
self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"])
self.logger = logging.getLogger("WeChatPadMessageConverter")
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
self.logger = logging.getLogger('WeChatPadMessageConverter')
@staticmethod
async def yiri2target(
message_chain: platform_message.MessageChain
) -> list[dict]:
async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
content_list = []
current_file_path = os.path.abspath(__file__)
_ = os.path.abspath(__file__)
for component in message_chain:
if isinstance(component, platform_message.At):
content_list.append({"type": "at", "target": component.target})
content_list.append({'type': 'at', 'target': component.target})
elif isinstance(component, platform_message.Plain):
content_list.append({"type": "text", "content": component.text})
content_list.append({'type': 'text', 'content': component.text})
elif isinstance(component, platform_message.Image):
if component.url:
async with httpx.AsyncClient() as client:
@@ -68,15 +56,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
else:
raise Exception('获取文件失败')
# pass
content_list.append({"type": "image", "image": base64_str})
content_list.append({'type': 'image', 'image': base64_str})
elif component.base64:
content_list.append({"type": "image", "image": component.base64})
content_list.append({'type': 'image', 'image': component.base64})
elif isinstance(component, platform_message.WeChatEmoji):
content_list.append(
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size})
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}
)
elif isinstance(component, platform_message.Voice):
content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0})
content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0})
elif isinstance(component, platform_message.WeChatAppMsg):
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
elif isinstance(component, platform_message.Forward):
@@ -86,28 +75,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return content_list
async def target2yiri(
self,
message: dict,
bot_account_id: str
) -> platform_message.MessageChain:
async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain:
"""外部消息转平台消息"""
# 数据预处理
message_list = []
ats_bot = False # 是否被@
content = message["content"]["str"]
content = message['content']['str']
content_no_preifx = content # 群消息则去掉前缀
is_group_message = self._is_group_message(message)
if is_group_message:
ats_bot = self._ats_bot(message, bot_account_id)
if "@所有人" in content:
if '@所有人' in content:
message_list.append(platform_message.AtAll())
elif ats_bot:
message_list.append(platform_message.At(target=bot_account_id))
content_no_preifx, _ = self._extract_content_and_sender(content)
msg_type = message["msg_type"]
msg_type = message['msg_type']
# 映射消息类型到处理器方法
handler_map = {
@@ -129,11 +113,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list)
async def _handler_text(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理文本消息 (msg_type=1)"""
if message and self._is_group_message(message):
pattern = r'@\S{1,20}'
@@ -141,16 +121,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
async def _handler_image(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理图像消息 (msg_type=3)"""
try:
image_xml = content_no_preifx
if not image_xml:
return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")])
return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
root = ET.fromstring(image_xml)
# 提取img标签的属性
@@ -160,28 +136,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
cdnthumburl = img_tag.get('cdnthumburl')
# cdnmidimgurl = img_tag.get('cdnmidimgurl')
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl)
if image_data["Data"]['FileData'] == '':
if image_data['Data']['FileData'] == '':
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl)
base64_str = image_data["Data"]['FileData']
base64_str = image_data['Data']['FileData']
# self.logger.info(f"data:image/png;base64,{base64_str}")
elements = [
platform_message.Image(base64=f"data:image/png;base64,{base64_str}"),
platform_message.Image(base64=f'data:image/png;base64,{base64_str}'),
# platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发
]
return platform_message.MessageChain(elements)
except Exception as e:
self.logger.error(f"处理图片失败: {str(e)}")
return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")])
self.logger.error(f'处理图片失败: {str(e)}')
return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
async def _handler_voice(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理语音消息 (msg_type=34)"""
message_List = []
try:
@@ -197,39 +167,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
bufid = voicemsg.get('bufid')
length = voicemsg.get('voicelength')
voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id))
audio_base64 = voice_data["Data"]['Base64']
audio_base64 = voice_data['Data']['Base64']
# 验证语音数据有效性
if not audio_base64:
message_List.append(platform_message.Unknown(text="[语音内容为空]"))
message_List.append(platform_message.Unknown(text='[语音内容为空]'))
return platform_message.MessageChain(message_List)
# 转换为平台支持的语音格式(如 Silk 格式)
voice_element = platform_message.Voice(
base64=f"data:audio/silk;base64,{audio_base64}"
)
voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
message_List.append(voice_element)
except KeyError as e:
self.logger.error(f"语音数据字段缺失: {str(e)}")
message_List.append(platform_message.Unknown(text="[语音数据解析失败]"))
self.logger.error(f'语音数据字段缺失: {str(e)}')
message_List.append(platform_message.Unknown(text='[语音数据解析失败]'))
except Exception as e:
self.logger.error(f"处理语音消息异常: {str(e)}")
message_List.append(platform_message.Unknown(text="[语音处理失败]"))
self.logger.error(f'处理语音消息异常: {str(e)}')
message_List.append(platform_message.Unknown(text='[语音处理失败]'))
return platform_message.MessageChain(message_List)
async def _handler_compound(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理复合消息 (msg_type=49),根据子类型分派"""
try:
xml_data = ET.fromstring(content_no_preifx)
appmsg_data = xml_data.find('.//appmsg')
if appmsg_data:
data_type = appmsg_data.findtext('.//type', "")
data_type = appmsg_data.findtext('.//type', '')
# 二次分派处理器
sub_handler_map = {
'57': self._handler_compound_quote,
@@ -238,9 +202,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
'74': self._handler_compound_file,
'33': self._handler_compound_mini_program,
'36': self._handler_compound_mini_program,
'2000': partial(self._handler_compound_unsupported, text="[转账消息]"),
'2001': partial(self._handler_compound_unsupported, text="[红包消息]"),
'51': partial(self._handler_compound_unsupported, text="[视频号消息]"),
'2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
'2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
'51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
}
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
@@ -251,56 +215,54 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
else:
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
except Exception as e:
self.logger.error(f"解析复合消息失败: {str(e)}")
self.logger.error(f'解析复合消息失败: {str(e)}')
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
async def _handler_compound_quote(
self,
message: Optional[dict],
xml_data: ET.Element
self, message: Optional[dict], xml_data: ET.Element
) -> platform_message.MessageChain:
"""处理引用消息 (data_type=57)"""
message_list = []
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
appmsg_data = xml_data.find('.//appmsg')
quote_data = "" # 引用原文
quote_data = '' # 引用原文
quote_id = None # 引用消息的原发送者
tousername = None # 接收方: 所属微信的wxid
user_data = "" # 用户消息
user_data = '' # 用户消息
sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member
# 引用消息转发
if appmsg_data:
user_data = appmsg_data.findtext('.//title') or ""
user_data = appmsg_data.findtext('.//title') or ''
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr')
message_list.append(
platform_message.WeChatAppMsg(
app_msg=ET.tostring(appmsg_data, encoding='unicode'))
)
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode')))
if message:
tousername = message['to_user_name']["str"]
tousername = message['to_user_name']['str']
_ = tousername
_ = quote_id
if quote_data:
quote_data_message_list = platform_message.MessageChain()
# 文本消息
try:
if "<msg>" not in quote_data:
if '<msg>' not in quote_data:
quote_data_message_list.append(platform_message.Plain(quote_data))
else:
# 引用消息展开
quote_data_xml = ET.fromstring(quote_data)
if quote_data_xml.find("img"):
if quote_data_xml.find('img'):
quote_data_message_list.extend(await self._handler_image(None, quote_data))
elif quote_data_xml.find("voicemsg"):
elif quote_data_xml.find('voicemsg'):
quote_data_message_list.extend(await self._handler_voice(None, quote_data))
elif quote_data_xml.find("videomsg"):
elif quote_data_xml.find('videomsg'):
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
else:
# appmsg
quote_data_message_list.extend(await self._handler_compound(None, quote_data))
except Exception as e:
self.logger.error(f"处理引用消息异常 expcetion:{e}")
self.logger.error(f'处理引用消息异常 expcetion:{e}')
quote_data_message_list.append(platform_message.Plain(quote_data))
message_list.append(
platform_message.Quote(
@@ -315,11 +277,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list)
async def _handler_compound_file(
self,
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理文件消息 (data_type=6)"""
file_data = xml_data.find('.//appmsg')
@@ -357,11 +315,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
platform_message.WeChatForwardFile(xml_data=xml_data_str)
])
async def _handler_compound_link(
self,
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain:
async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
"""处理链接消息(如公众号文章、外部网页)"""
message_list = []
try:
@@ -374,56 +328,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
link_title=appmsg.findtext('title', ''),
link_desc=appmsg.findtext('des', ''),
link_url=appmsg.findtext('url', ''),
link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到
link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到
)
)
# 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。
message_list.append(
platform_message.WeChatAppMsg(
app_msg=ET.tostring(appmsg, encoding='unicode')
)
)
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode')))
except Exception as e:
self.logger.error(f"解析链接消息失败: {str(e)}")
self.logger.error(f'解析链接消息失败: {str(e)}')
return platform_message.MessageChain(message_list)
async def _handler_compound_mini_program(
self,
message: dict,
xml_data: ET.Element
self, message: dict, xml_data: ET.Element
) -> platform_message.MessageChain:
"""处理小程序消息(如小程序卡片、服务通知)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain([
platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)
])
return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
async def _handler_default(
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
"""处理未知消息类型"""
if message:
msg_type = message["msg_type"]
msg_type = message['msg_type']
else:
msg_type = ""
return platform_message.MessageChain([
platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]")
])
msg_type = ''
return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
def _handler_compound_unsupported(
self,
message: dict,
xml_data: str,
text: Optional[str] = None
self, message: dict, xml_data: str, text: Optional[str] = None
) -> platform_message.MessageChain:
"""处理未支持复合消息类型(msg_type=49)子类型"""
if not text:
text = f"[xml_data={xml_data}]"
text = f'[xml_data={xml_data}]'
content_list = []
content_list.append(
platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}"))
content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
return platform_message.MessageChain(content_list)
@@ -432,7 +368,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
ats_bot = False
try:
to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid
raw_content = message["content"]["str"] # 原始消息内容
raw_content = message['content']['str'] # 原始消息内容
content_no_prefix, _ = self._extract_content_and_sender(raw_content)
# 直接艾特机器人这个有bug当被引用的消息里面有@bot,会套娃
# ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix)
@@ -443,7 +379,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
msg_source = message.get('msg_source', '') or ''
if len(msg_source) > 0:
msg_source_data = ET.fromstring(msg_source)
at_user_list = msg_source_data.findtext("atuserlist") or ""
at_user_list = msg_source_data.findtext('atuserlist') or ''
ats_bot = ats_bot or (to_user_name in at_user_list)
# 引用bot
if message.get('msg_type', 0) == 49:
@@ -454,7 +390,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
ats_bot = ats_bot or (quote_id == tousername)
except Exception as e:
self.logger.error(f"_ats_bot got except: {e}")
self.logger.error(f'_ats_bot got except: {e}')
finally:
return ats_bot
@@ -463,47 +399,41 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
try:
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
# add: 有些用户的wxid不是上述格式。换成user_name:
regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:")
line_split = raw_content.split("\n")
regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:')
line_split = raw_content.split('\n')
if len(line_split) > 0 and regex.match(line_split[0]):
raw_content = "\n".join(line_split[1:])
sender_id = line_split[0].strip(":")
raw_content = '\n'.join(line_split[1:])
sender_id = line_split[0].strip(':')
return raw_content, sender_id
except Exception as e:
self.logger.error(f"_extract_content_and_sender got except: {e}")
self.logger.error(f'_extract_content_and_sender got except: {e}')
finally:
return raw_content, None
# 是否是群消息
def _is_group_message(self, message: dict) -> bool:
from_user_name = message['from_user_name']['str']
return from_user_name.endswith("@chatroom")
return from_user_name.endswith('@chatroom')
class WeChatPadEventConverter(adapter.EventConverter):
def __init__(self, config: dict):
self.config = config
self.message_converter = WeChatPadMessageConverter(config)
self.logger = logging.getLogger("WeChatPadEventConverter")
self.logger = logging.getLogger('WeChatPadEventConverter')
@staticmethod
async def yiri2target(
event: platform_events.MessageEvent
) -> dict:
async def yiri2target(event: platform_events.MessageEvent) -> dict:
pass
async def target2yiri(
self,
event: dict,
bot_account_id: str
) -> platform_events.MessageEvent:
async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent:
# 排除公众号以及微信团队消息
if event['from_user_name']['str'].startswith('gh_') \
or event['from_user_name']['str']=='weixin'\
or event['from_user_name']['str'] == "newsapp"\
or event['from_user_name']['str'] == self.config["wxid"]:
if (
event['from_user_name']['str'].startswith('gh_')
or event['from_user_name']['str'] == 'weixin'
or event['from_user_name']['str'] == 'newsapp'
or event['from_user_name']['str'] == self.config['wxid']
):
return None
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
@@ -512,7 +442,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
if '@chatroom' in event['from_user_name']['str']:
# 找出开头的 wxid_ 字符串,以:结尾
sender_wxid = event['content']['str'].split(":")[0]
sender_wxid = event['content']['str'].split(':')[0]
return platform_events.GroupMessage(
sender=platform_entities.GroupMember(
@@ -524,13 +454,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
name=event['from_user_name']['str'],
permission=platform_entities.Permission.Member,
),
special_title="",
special_title='',
join_timestamp=0,
last_speak_timestamp=0,
mute_time_remaining=0,
),
message_chain=message_chain,
time=event["create_time"],
time=event['create_time'],
source_platform_object=event,
)
else:
@@ -541,13 +471,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
remark='',
),
message_chain=message_chain,
time=event["create_time"],
time=event['create_time'],
source_platform_object=event,
)
class WeChatPadAdapter(adapter.MessagePlatformAdapter):
name: str = "WeChatPad" # 定义适配器名称
name: str = 'WeChatPad' # 定义适配器名称
bot: WeChatPadClient
quart_app: quart.Quart
@@ -580,27 +510,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# self.ap.logger.debug(f"Gewechat callback event: {data}")
# print(data)
try:
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
except Exception as e:
await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}')
if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self)
return 'ok'
async def _handle_message(
self,
message: platform_message.MessageChain,
target_id: str
):
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
"""统一消息处理核心逻辑"""
content_list = await self.message_converter.yiri2target(message)
# print(content_list)
at_targets = [item["target"] for item in content_list if item["type"] == "at"]
at_targets = [item['target'] for item in content_list if item['type'] == 'at']
# print(at_targets)
# 处理@逻辑
at_targets = at_targets or []
@@ -608,7 +532,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if at_targets:
member_info = self.bot.get_chatroom_member_detail(
target_id,
)["Data"]["member_data"]["chatroom_member_list"]
)['Data']['member_data']['chatroom_member_list']
# 处理消息组件
for msg in content_list:
@@ -616,63 +540,51 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if msg['type'] == 'text' and at_targets:
at_nick_name_list = []
for member in member_info:
if member["user_name"] in at_targets:
if member['user_name'] in at_targets:
at_nick_name_list.append(f'@{member["nick_name"]}')
msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}'
# 统一消息派发
handler_map = {
'text': lambda msg: self.bot.send_text_message(
to_wxid=target_id,
message=msg['content'],
ats=at_targets
to_wxid=target_id, message=msg['content'], ats=at_targets
),
'image': lambda msg: self.bot.send_image_message(
to_wxid=target_id,
img_url=msg["image"],
ats = at_targets
to_wxid=target_id, img_url=msg['image'], ats=at_targets
),
'WeChatEmoji': lambda msg: self.bot.send_emoji_message(
to_wxid=target_id,
emoji_md5=msg['emoji_md5'],
emoji_size=msg['emoji_size']
to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size']
),
'voice': lambda msg: self.bot.send_voice_message(
to_wxid=target_id,
voice_data=msg['data'],
voice_duration=msg["duration"],
voice_forma=msg["forma"],
voice_duration=msg['duration'],
voice_forma=msg['forma'],
),
'WeChatAppMsg': lambda msg: self.bot.send_app_message(
to_wxid=target_id,
app_message=msg['app_msg'],
type=0,
),
'at': lambda msg: None
'at': lambda msg: None,
}
if handler := handler_map.get(msg['type']):
handler(msg)
# self.ap.logger.warning(f"未处理的消息类型: {ret}")
else:
self.ap.logger.warning(f"未处理的消息类型: {msg['type']}")
self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue
async def send_message(
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain
):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""主动发送消息"""
return await self._handle_message(message, target_id)
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
"""回复消息"""
if message_source.source_platform_object:
@@ -683,58 +595,49 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
pass
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
self.listeners[event_type] = callback
def unregister_listener(
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None]
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
):
pass
async def run_async(self):
if not self.config["admin_key"] and not self.config["token"]:
raise RuntimeError("无wechatpad管理密匙请填入配置文件后重启")
if not self.config['admin_key'] and not self.config['token']:
raise RuntimeError('无wechatpad管理密匙请填入配置文件后重启')
else:
if self.config["token"]:
self.bot = WeChatPadClient(
self.config['wechatpad_url'],
self.config["token"]
)
if self.config['token']:
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
data = self.bot.get_login_status()
self.ap.logger.info(data)
if data["Code"] == 300 and data["Text"] == "你已退出微信":
if data['Code'] == 300 and data['Text'] == '你已退出微信':
response = requests.post(
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}",
json={"Count": 1, "Days": 365}
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
json={'Count': 1, 'Days': 365},
)
if response.status_code != 200:
raise Exception(f"获取token失败: {response.text}")
self.config["token"] = response.json()["Data"][0]
raise Exception(f'获取token失败: {response.text}')
self.config['token'] = response.json()['Data'][0]
elif not self.config["token"]:
elif not self.config['token']:
response = requests.post(
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}",
json={"Count": 1, "Days": 365}
f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
json={'Count': 1, 'Days': 365},
)
if response.status_code != 200:
raise Exception(f"获取token失败: {response.text}")
self.config["token"] = response.json()["Data"][0]
raise Exception(f'获取token失败: {response.text}')
self.config['token'] = response.json()['Data'][0]
self.bot = WeChatPadClient(
self.config['wechatpad_url'],
self.config["token"],
logger=self.logger
)
self.ap.logger.info(self.config["token"])
self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
self.ap.logger.info(self.config['token'])
thread_1 = threading.Event()
def wechat_login_process():
# 不登录这些先注释掉避免登陆态尝试拉qrcode。
# login_data =self.bot.get_login_qr()
@@ -742,67 +645,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# url = login_data['Data']["QrCodeUrl"]
# self.ap.logger.info(login_data)
profile =self.bot.get_profile()
profile = self.bot.get_profile()
self.ap.logger.info(profile)
self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"]
self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"]
self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
thread_1.set()
# asyncio.create_task(wechat_login_process)
threading.Thread(target=wechat_login_process).start()
def connect_websocket_sync() -> None:
thread_1.wait()
uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}"
self.ap.logger.info(f"Connecting to WebSocket: {uri}")
uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
self.ap.logger.info(f'Connecting to WebSocket: {uri}')
def on_message(ws, message):
try:
data = json.loads(message)
self.ap.logger.debug(f"Received message: {data}")
self.ap.logger.debug(f'Received message: {data}')
# 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法
asyncio.run(self.ws_message(data))
except json.JSONDecodeError:
self.ap.logger.error(f"Non-JSON message: {message[:100]}...")
self.ap.logger.error(f'Non-JSON message: {message[:100]}...')
def on_error(ws, error):
self.ap.logger.error(f"WebSocket error: {str(error)[:200]}")
self.ap.logger.error(f'WebSocket error: {str(error)[:200]}')
def on_close(ws, close_status_code, close_msg):
self.ap.logger.info("WebSocket closed, reconnecting...")
self.ap.logger.info('WebSocket closed, reconnecting...')
time.sleep(5)
connect_websocket_sync() # 自动重连
def on_open(ws):
self.ap.logger.info("WebSocket connected successfully!")
self.ap.logger.info('WebSocket connected successfully!')
ws = websocket.WebSocketApp(
uri,
on_message=on_message,
on_error=on_error,
on_close=on_close,
on_open=on_open
)
ws.run_forever(
ping_interval=60,
ping_timeout=20
uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
)
ws.run_forever(ping_interval=60, ping_timeout=20)
# 直接调用同步版本(会阻塞)
# connect_websocket_sync()
# 这行代码会在WebSocket连接断开后才会执行
# self.ap.logger.info("WebSocket client thread started")
thread = threading.Thread(
target=connect_websocket_sync,
name="WebSocketClientThread",
daemon=True
)
thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
thread.start()
self.ap.logger.info("WebSocket client thread started")
self.ap.logger.info('WebSocket client thread started')
async def kill(self) -> bool:
pass

View File

@@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
contacts_secret=config['contacts_secret'],
logger=self.logger
logger=self.logger,
)
async def reply_message(
@@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message)

View File

@@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
secret=config['secret'],
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
logger=self.logger
logger=self.logger,
)
async def reply_message(
@@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e:
await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}")
except Exception:
await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message)

View File

@@ -7,6 +7,7 @@ import pydantic.v1 as pydantic
from ..core import entities as core_entities
from ..provider import entities as llm_entities
from ..platform.types import message as platform_message
import langbot_plugin.api.entities.builtin.provider.session as provider_session
class BaseEventModel(pydantic.BaseModel):
@@ -139,7 +140,7 @@ class NormalMessageResponded(BaseEventModel):
sender_id: typing.Union[int, str]
session: core_entities.Session
session: provider_session.Session
"""会话对象"""
prefix: str

View File

@@ -73,3 +73,18 @@ class RuntimeConnectionHandler(handler.Handler):
)
return result['plugins']
async def emit_event(
self,
event_context: dict[str, Any],
) -> dict[str, Any]:
"""Emit event"""
result = await self.call_action(
LangBotToRuntimeAction.EMIT_EVENT,
{
'event_context': event_context,
},
timeout=10,
)
return result['event_context']

View File

@@ -7,7 +7,7 @@ import traceback
from .. import loader, events, context, models
from ...core import entities as core_entities
from ...provider.tools import entities as tools_entities
from langbot_plugin.api.entities.builtin.resource import tool as resource_tool
from ...utils import funcschema
from ...discover import engine as discover_engine
@@ -101,7 +101,7 @@ class PluginLoader(loader.PluginLoader):
async def handler(plugin: context.BasePlugin, query: core_entities.Query, *args, **kwargs):
return func(*args, **kwargs)
llm_function = tools_entities.LLMFunction(
llm_function = resource_tool.LLMTool(
name=function_name,
human_desc='',
description=function_schema['description'],
@@ -147,7 +147,7 @@ class PluginLoader(loader.PluginLoader):
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
llm_function = resource_tool.LLMTool(
name=function_name,
human_desc='',
description=function_schema['description'],

View File

@@ -8,7 +8,7 @@ from ...core import app
from .. import context, events
from .. import loader
from ...utils import funcschema
from ...provider.tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class PluginManifestLoader(loader.PluginLoader):
@@ -41,7 +41,7 @@ class PluginManifestLoader(loader.PluginLoader):
function_schema = funcschema.get_func_schema(func)
function_name = self._current_container.plugin_name + '-' + (func.__name__ if name is None else name)
llm_function = tools_entities.LLMFunction(
llm_function = resource_tool.LLMTool(
name=function_name,
human_desc='',
description=function_schema['description'],

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import sqlalchemy
import traceback
from . import entities, requester
from . import requester
from ...core import app
from ...discover import engine
from . import token
@@ -16,14 +16,6 @@ FETCH_MODEL_LIST_URL = 'https://api.qchatgpt.rockchin.top/api/v2/fetch/model_lis
class ModelManager:
"""模型管理器"""
model_list: list[entities.LLMModelInfo] # deprecated
requesters: dict[str, requester.LLMAPIRequester] # deprecated
token_mgrs: dict[str, token.TokenManager] # deprecated
# ====== 4.0 ======
ap: app.Application
llm_models: list[requester.RuntimeLLMModel]
@@ -34,9 +26,6 @@ class ModelManager:
def __init__(self, ap: app.Application):
self.ap = ap
self.model_list = []
self.requesters = {}
self.token_mgrs = {}
self.llm_models = []
self.requester_components = []
self.requester_dict = {}
@@ -109,14 +98,7 @@ class ModelManager:
runtime_llm_model = await self.init_runtime_llm_model(model_info)
self.llm_models.append(runtime_llm_model)
async def get_model_by_name(self, name: str) -> entities.LLMModelInfo: # deprecated
"""通过名称获取模型"""
for model in self.model_list:
if model.name == name:
return model
raise ValueError(f'无法确定模型 {name} 的信息')
async def get_model_by_uuid(self, uuid: str) -> entities.LLMModelInfo:
async def get_model_by_uuid(self, uuid: str) -> requester.RuntimeLLMModel:
"""通过uuid获取模型"""
for model in self.llm_models:
if model.model_entity.uuid == uuid:

View File

@@ -6,8 +6,8 @@ import typing
from ...core import app
from ...core import entities as core_entities
from .. import entities as llm_entities
from ..tools import entities as tools_entities
from ...entity.persistence import model as persistence_model
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from . import token
@@ -59,7 +59,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
query: core_entities.Query,
model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
"""调用API

View File

@@ -11,8 +11,8 @@ from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
from ....utils import image
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class AnthropicMessages(requester.LLMAPIRequester):
@@ -51,7 +51,7 @@ class AnthropicMessages(requester.LLMAPIRequester):
query: core_entities.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = model.token_mgr.get_token()

View File

@@ -10,7 +10,7 @@ import httpx
from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class OpenAIChatCompletions(requester.LLMAPIRequester):
@@ -63,7 +63,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
@@ -104,7 +104,7 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
query: core_entities.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行

View File

@@ -6,7 +6,7 @@ from . import chatcmpl
from .. import errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -22,7 +22,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()

View File

@@ -7,7 +7,7 @@ from . import chatcmpl
from .. import requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -23,7 +23,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()

View File

@@ -11,7 +11,7 @@ import httpx
from .. import entities, errors, requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class ModelScopeChatCompletions(requester.LLMAPIRequester):
@@ -128,7 +128,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()
@@ -169,7 +169,7 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
query: core_entities.Query,
model: entities.LLMModelInfo,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行

View File

@@ -7,7 +7,7 @@ from . import chatcmpl
from .. import requester
from ....core import entities as core_entities
from ... import entities as llm_entities
from ...tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
@@ -23,7 +23,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token()

View File

@@ -11,7 +11,7 @@ import ollama
from .. import errors, requester
from ... import entities as llm_entities
from ...tools import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
from ....core import entities as core_entities
REQUESTER_NAME: str = 'ollama-chat'
@@ -42,7 +42,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
query: core_entities.Query,
req_messages: list[dict],
use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None,
use_funcs: list[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
args = extra_args.copy()
@@ -108,7 +108,7 @@ class OllamaChatCompletions(requester.LLMAPIRequester):
query: core_entities.Query,
model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None,
funcs: typing.List[resource_tool.LLMTool] = None,
extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message:
req_messages: list = []

View File

@@ -18,13 +18,15 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
use_llm_model = await self.ap.model_mgr.get_model_by_uuid(query.use_llm_model_uuid)
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(
msg = await use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
use_llm_model,
req_messages,
query.use_funcs,
extra_args=query.use_llm_model.model_entity.extra_args,
extra_args=use_llm_model.model_entity.extra_args,
)
yield msg
@@ -61,12 +63,12 @@ class LocalAgentRunner(runner.RequestRunner):
req_messages.append(err_msg)
# 处理完所有调用,再次请求
msg = await query.use_llm_model.requester.invoke_llm(
msg = await use_llm_model.requester.invoke_llm(
query,
query.use_llm_model,
use_llm_model,
req_messages,
query.use_funcs,
extra_args=query.use_llm_model.model_entity.extra_args,
extra_args=use_llm_model.model_entity.extra_args,
)
yield msg

View File

@@ -3,7 +3,8 @@ from __future__ import annotations
import asyncio
from ...core import app, entities as core_entities
from ...provider import entities as provider_entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message, prompt as provider_prompt
import langbot_plugin.api.entities.builtin.provider.session as provider_session
class SessionManager:
@@ -11,7 +12,7 @@ class SessionManager:
ap: app.Application
session_list: list[core_entities.Session]
session_list: list[provider_session.Session]
def __init__(self, ap: app.Application):
self.ap = ap
@@ -20,7 +21,7 @@ class SessionManager:
async def initialize(self):
pass
async def get_session(self, query: core_entities.Query) -> core_entities.Session:
async def get_session(self, query: core_entities.Query) -> provider_session.Session:
"""获取会话"""
for session in self.session_list:
if query.launcher_type == session.launcher_type and query.launcher_id == session.launcher_id:
@@ -28,7 +29,7 @@ class SessionManager:
session_concurrency = self.ap.instance_config.data['concurrency']['session']
session = core_entities.Session(
session = provider_session.Session(
launcher_type=query.launcher_type,
launcher_id=query.launcher_id,
semaphore=asyncio.Semaphore(session_concurrency),
@@ -39,11 +40,11 @@ class SessionManager:
async def get_conversation(
self,
query: core_entities.Query,
session: core_entities.Session,
session: provider_session.Session,
prompt_config: list[dict],
pipeline_uuid: str,
bot_uuid: str,
) -> core_entities.Conversation:
) -> provider_session.Conversation:
"""获取对话或创建对话"""
if not session.conversations:
@@ -53,20 +54,17 @@ class SessionManager:
prompt_messages = []
for prompt_message in prompt_config:
prompt_messages.append(provider_entities.Message(**prompt_message))
prompt_messages.append(provider_message.Message(**prompt_message))
prompt = provider_entities.Prompt(
prompt = provider_prompt.Prompt(
name='default',
messages=prompt_messages,
)
if session.using_conversation is None or session.using_conversation.pipeline_uuid != pipeline_uuid:
conversation = core_entities.Conversation(
conversation = provider_session.Conversation(
prompt=prompt,
messages=[],
use_funcs=await self.ap.tool_mgr.get_all_functions(
plugin_enabled=True,
),
pipeline_uuid=pipeline_uuid,
bot_uuid=bot_uuid,
)

View File

@@ -1,31 +0,0 @@
from __future__ import annotations
import typing
import pydantic.v1 as pydantic
class LLMFunction(pydantic.BaseModel):
"""函数"""
name: str
"""函数名"""
human_desc: str
description: str
"""给LLM识别的函数描述"""
parameters: dict
func: typing.Callable
"""供调用的python异步方法
此异步方法第一个参数接收当前请求的query对象可以从其中取出session等信息。
query参数不在parameters中但在调用时会自动传入。
但在当前版本中,插件提供的内容函数都是同步的,且均为请求无关的,故在此版本的实现(以及考虑了向后兼容性的版本)中,
对插件的内容函数进行封装并存到这里来。
"""
class Config:
arbitrary_types_allowed = True

View File

@@ -4,7 +4,7 @@ import abc
import typing
from ...core import app, entities as core_entities
from . import entities as tools_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
preregistered_loaders: list[typing.Type[ToolLoader]] = []
@@ -35,7 +35,7 @@ class ToolLoader(abc.ABC):
pass
@abc.abstractmethod
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
"""获取所有工具"""
pass

View File

@@ -7,8 +7,9 @@ from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
from mcp.client.sse import sse_client
from .. import loader, entities as tools_entities
from .. import loader
from ....core import app, entities as core_entities
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
class RuntimeMCPSession:
@@ -24,7 +25,7 @@ class RuntimeMCPSession:
exit_stack: AsyncExitStack
functions: list[tools_entities.LLMFunction] = []
functions: list[resource_tool.LLMTool] = []
def __init__(self, server_name: str, server_config: dict, ap: app.Application):
self.server_name = server_name
@@ -91,7 +92,7 @@ class RuntimeMCPSession:
func.__name__ = tool.name
self.functions.append(
tools_entities.LLMFunction(
resource_tool.LLMTool(
name=tool.name,
human_desc=tool.description,
description=tool.description,
@@ -114,7 +115,7 @@ class MCPLoader(loader.ToolLoader):
sessions: dict[str, RuntimeMCPSession] = {}
_last_listed_functions: list[tools_entities.LLMFunction] = []
_last_listed_functions: list[resource_tool.LLMTool] = []
def __init__(self, ap: app.Application):
super().__init__(ap)
@@ -130,7 +131,7 @@ class MCPLoader(loader.ToolLoader):
# self.ap.event_loop.create_task(session.initialize())
self.sessions[server_config['name']] = session
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
all_functions = []
for session in self.sessions.values():

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
import typing
import traceback
from .. import loader, entities as tools_entities
from .. import loader
from ....core import entities as core_entities
from ....plugin import context as plugin_context
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
@loader.loader_class('plugin-tool-loader')
@@ -15,9 +16,9 @@ class PluginToolLoader(loader.ToolLoader):
本加载器中不存储工具信息,仅负责从插件系统中获取工具信息。
"""
async def get_tools(self, enabled: bool = True) -> list[tools_entities.LLMFunction]:
async def get_tools(self, enabled: bool = True) -> list[resource_tool.LLMTool]:
# 从插件系统获取工具(内容函数)
all_functions: list[tools_entities.LLMFunction] = []
all_functions: list[resource_tool.LLMTool] = []
for plugin in self.ap.plugin_mgr.plugins(
enabled=enabled, status=plugin_context.RuntimeContainerStatus.INITIALIZED
@@ -38,7 +39,7 @@ class PluginToolLoader(loader.ToolLoader):
async def _get_function_and_plugin(
self, name: str
) -> typing.Tuple[tools_entities.LLMFunction, plugin_context.BasePlugin]:
) -> typing.Tuple[resource_tool.LLMTool, plugin_context.BasePlugin]:
"""获取函数和插件实例"""
for plugin in self.ap.plugin_mgr.plugins(
enabled=True, status=plugin_context.RuntimeContainerStatus.INITIALIZED

View File

@@ -3,9 +3,10 @@ from __future__ import annotations
import typing
from ...core import app, entities as core_entities
from . import entities, loader as tools_loader
from . import loader as tools_loader
from ...utils import importutil
from . import loaders
import langbot_plugin.api.entities.builtin.resource.tool as resource_tool
importutil.import_modules_in_pkg(loaders)
@@ -28,16 +29,16 @@ class ToolManager:
await loader_inst.initialize()
self.loaders.append(loader_inst)
async def get_all_functions(self, plugin_enabled: bool = None) -> list[entities.LLMFunction]:
async def get_all_functions(self, plugin_enabled: bool = None) -> list[resource_tool.LLMTool]:
"""获取所有函数"""
all_functions: list[entities.LLMFunction] = []
all_functions: list[resource_tool.LLMTool] = []
for loader in self.loaders:
all_functions.extend(await loader.get_tools(plugin_enabled))
return all_functions
async def generate_tools_for_openai(self, use_funcs: list[entities.LLMFunction]) -> list:
async def generate_tools_for_openai(self, use_funcs: list[resource_tool.LLMTool]) -> list:
"""生成函数列表"""
tools = []
@@ -54,7 +55,7 @@ class ToolManager:
return tools
async def generate_tools_for_anthropic(self, use_funcs: list[entities.LLMFunction]) -> list:
async def generate_tools_for_anthropic(self, use_funcs: list[resource_tool.LLMTool]) -> list:
"""为anthropic生成函数列表
e.g.

View File

@@ -204,9 +204,9 @@ async def get_slack_image_to_base64(pic_url: str, bot_token: str):
try:
async with aiohttp.ClientSession() as session:
async with session.get(pic_url, headers=headers) as resp:
mime_type = resp.headers.get("Content-Type", "application/octet-stream")
mime_type = resp.headers.get('Content-Type', 'application/octet-stream')
file_bytes = await resp.read()
base64_str = base64.b64encode(file_bytes).decode("utf-8")
return f"data:{mime_type};base64,{base64_str}"
base64_str = base64.b64encode(file_bytes).decode('utf-8')
return f'data:{mime_type};base64,{base64_str}'
except Exception as e:
raise (e)
raise (e)

View File

@@ -32,7 +32,7 @@ def import_dir(path: str):
rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '')
rel_path = rel_path[1:]
rel_path = rel_path.replace('/', '.')[:-3]
rel_path = rel_path.replace("\\",".")
rel_path = rel_path.replace('\\', '.')
importlib.import_module(rel_path)