perf: ruff format & remove stream params in requester

This commit is contained in:
Junyan Qin
2025-08-03 13:08:51 +08:00
parent 52280d7a05
commit 377d455ec1
39 changed files with 685 additions and 856 deletions

View File

@@ -3,7 +3,6 @@ import json
import time import time
from typing import Callable from typing import Callable
import dingtalk_stream # type: ignore import dingtalk_stream # type: ignore
from dingtalk_stream import AckMessage, ChatbotHandler, CallbackHandler, CallbackMessage, ChatbotMessage, AICardReplier
from .EchoHandler import EchoTextHandler from .EchoHandler import EchoTextHandler
from .dingtalkevent import DingTalkEvent from .dingtalkevent import DingTalkEvent
import httpx import httpx
@@ -254,24 +253,23 @@ class DingTalkClient:
await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}') await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}')
raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}') raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}')
async def create_and_card(self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage,quote_origin:bool=False): async def create_and_card(
content_key = "content" self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage, quote_origin: bool = False
card_data = {content_key: ""} ):
content_key = 'content'
card_data = {content_key: ''}
card_instance = dingtalk_stream.AICardReplier( card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message)
self.client, incoming_message
)
# print(card_instance) # print(card_instance)
# 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards # 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards
card_instance_id = await card_instance.async_create_and_deliver_card( card_instance_id = await card_instance.async_create_and_deliver_card(
temp_card_id, card_data, temp_card_id,
card_data,
) )
return card_instance, card_instance_id return card_instance, card_instance_id
async def send_card_message(self, async def send_card_message(self, card_instance, card_instance_id: str, content: str, is_final: bool):
card_instance, content_key = 'content'
card_instance_id: str,content: str,is_final: bool):
content_key = "content"
try: try:
await card_instance.async_streaming( await card_instance.async_streaming(
card_instance_id, card_instance_id,
@@ -286,16 +284,12 @@ class DingTalkClient:
await card_instance.async_streaming( await card_instance.async_streaming(
card_instance_id, card_instance_id,
content_key=content_key, content_key=content_key,
content_value="", content_value='',
append=False, append=False,
finished=is_final, finished=is_final,
failed=True, failed=True,
) )
async def start(self): async def start(self):
"""启动 WebSocket 连接,监听消息""" """启动 WebSocket 连接,监听消息"""
await self.client.start() await self.client.start()

View File

@@ -1 +1 @@
from .client import WeChatPadClient from .client import WeChatPadClient as WeChatPadClient

View File

@@ -1,4 +1,4 @@
from libs.wechatpad_api.util.http_util import async_request, post_json from libs.wechatpad_api.util.http_util import post_json
class ChatRoomApi: class ChatRoomApi:
@@ -7,8 +7,6 @@ class ChatRoomApi:
self.token = token self.token = token
def get_chatroom_member_detail(self, chatroom_name): def get_chatroom_member_detail(self, chatroom_name):
params = { params = {'ChatRoomName': chatroom_name}
"ChatRoomName": chatroom_name
}
url = self.base_url + '/group/GetChatroomMemberDetail' url = self.base_url + '/group/GetChatroomMemberDetail'
return post_json(url, token=self.token, data=params) return post_json(url, token=self.token, data=params)

View File

@@ -1,32 +1,23 @@
from libs.wechatpad_api.util.http_util import async_request, post_json from libs.wechatpad_api.util.http_util import post_json
import httpx import httpx
import base64 import base64
class DownloadApi: class DownloadApi:
def __init__(self, base_url, token): def __init__(self, base_url, token):
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token
def send_download(self, aeskey, file_type, file_url): def send_download(self, aeskey, file_type, file_url):
json_data = { json_data = {'AesKey': aeskey, 'FileType': file_type, 'FileURL': file_url}
"AesKey": aeskey, url = self.base_url + '/message/SendCdnDownload'
"FileType": file_type,
"FileURL": file_url
}
url = self.base_url + "/message/SendCdnDownload"
return post_json(url, token=self.token, data=json_data) return post_json(url, token=self.token, data=json_data)
def get_msg_voice(self, buf_id, length, new_msgid): def get_msg_voice(self, buf_id, length, new_msgid):
json_data = { json_data = {'Bufid': buf_id, 'Length': length, 'NewMsgId': new_msgid, 'ToUserName': ''}
"Bufid": buf_id, url = self.base_url + '/message/GetMsgVoice'
"Length": length,
"NewMsgId": new_msgid,
"ToUserName": ""
}
url = self.base_url + "/message/GetMsgVoice"
return post_json(url, token=self.token, data=json_data) return post_json(url, token=self.token, data=json_data)
async def download_url_to_base64(self, download_url): async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get(download_url) response = await client.get(download_url)

View File

@@ -1,11 +1,6 @@
from libs.wechatpad_api.util.http_util import post_json,async_request
from typing import List, Dict, Any, Optional
class FriendApi: class FriendApi:
"""联系人API类处理所有与联系人相关的操作""" """联系人API类处理所有与联系人相关的操作"""
def __init__(self, base_url: str, token: str): def __init__(self, base_url: str, token: str):
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token

View File

@@ -1,37 +1,34 @@
from libs.wechatpad_api.util.http_util import async_request,post_json,get_json from libs.wechatpad_api.util.http_util import post_json, get_json
class LoginApi: class LoginApi:
def __init__(self, base_url: str, token: str = None, admin_key: str = None): def __init__(self, base_url: str, token: str = None, admin_key: str = None):
''' """
Args: Args:
base_url: 原始路径 base_url: 原始路径
token: token token: token
admin_key: 管理员key admin_key: 管理员key
''' """
self.base_url = base_url self.base_url = base_url
self.token = token self.token = token
# self.admin_key = admin_key # self.admin_key = admin_key
def get_token(self, admin_key, day: int = 365): def get_token(self, admin_key, day: int = 365):
# 获取普通token # 获取普通token
url = f"{self.base_url}/admin/GenAuthKey1" url = f'{self.base_url}/admin/GenAuthKey1'
json_data = { json_data = {'Count': 1, 'Days': day}
"Count": 1,
"Days": day
}
return post_json(base_url=url, token=admin_key, data=json_data) return post_json(base_url=url, token=admin_key, data=json_data)
def get_login_qr(self, Proxy: str = ""): def get_login_qr(self, Proxy: str = ''):
''' """
Args: Args:
Proxy:异地使用时代理 Proxy:异地使用时代理
Returns:json数据 Returns:json数据
''' """
""" """
{ {
@@ -50,53 +47,36 @@ class LoginApi:
""" """
# 获取登录二维码 # 获取登录二维码
url = f"{self.base_url}/login/GetLoginQrCodeNew" url = f'{self.base_url}/login/GetLoginQrCodeNew'
check = False check = False
if Proxy != "": if Proxy != '':
check = True check = True
json_data = { json_data = {'Check': check, 'Proxy': Proxy}
"Check": check,
"Proxy": Proxy
}
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def get_login_status(self): def get_login_status(self):
# 获取登录状态 # 获取登录状态
url = f'{self.base_url}/login/GetLoginStatus' url = f'{self.base_url}/login/GetLoginStatus'
return get_json(base_url=url, token=self.token) return get_json(base_url=url, token=self.token)
def logout(self): def logout(self):
# 退出登录 # 退出登录
url = f'{self.base_url}/login/LogOut' url = f'{self.base_url}/login/LogOut'
return post_json(base_url=url, token=self.token) return post_json(base_url=url, token=self.token)
def wake_up_login(self, Proxy: str = ''):
def wake_up_login(self, Proxy: str = ""):
# 唤醒登录 # 唤醒登录
url = f'{self.base_url}/login/WakeUpLogin' url = f'{self.base_url}/login/WakeUpLogin'
check = False check = False
if Proxy != "": if Proxy != '':
check = True check = True
json_data = { json_data = {'Check': check, 'Proxy': ''}
"Check": check,
"Proxy": ""
}
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def login(self, admin_key): def login(self, admin_key):
login_status = self.get_login_status() login_status = self.get_login_status()
if login_status["Code"] == 300 and login_status["Text"] == "你已退出微信": if login_status['Code'] == 300 and login_status['Text'] == '你已退出微信':
print("token已经失效重新获取") print('token已经失效重新获取')
token_data = self.get_token(admin_key) token_data = self.get_token(admin_key)
self.token = token_data["Data"][0] self.token = token_data['Data'][0]

View File

@@ -1,5 +1,4 @@
from libs.wechatpad_api.util.http_util import post_json
from libs.wechatpad_api.util.http_util import async_request, post_json
class MessageApi: class MessageApi:
@@ -8,7 +7,7 @@ class MessageApi:
self.token = token self.token = token
def post_text(self, to_wxid, content, ats: list = []): def post_text(self, to_wxid, content, ats: list = []):
''' """
Args: Args:
app_id: 微信id app_id: 微信id
@@ -18,106 +17,64 @@ class MessageApi:
Returns: Returns:
''' """
url = self.base_url + "/message/SendTextMessage" url = self.base_url + '/message/SendTextMessage'
"""发送文字消息""" """发送文字消息"""
json_data = { json_data = {
"MsgItem": [ 'MsgItem': [
{ {'AtWxIDList': ats, 'ImageContent': '', 'MsgType': 0, 'TextContent': content, 'ToUserName': to_wxid}
"AtWxIDList": ats,
"ImageContent": "",
"MsgType": 0,
"TextContent": content,
"ToUserName": to_wxid
}
] ]
} }
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_image(self, to_wxid, img_url, ats: list = []): def post_image(self, to_wxid, img_url, ats: list = []):
"""发送图片消息""" """发送图片消息"""
# 这里好像可以尝试发送多个暂时未测试 # 这里好像可以尝试发送多个暂时未测试
json_data = { json_data = {
"MsgItem": [ 'MsgItem': [
{ {'AtWxIDList': ats, 'ImageContent': img_url, 'MsgType': 0, 'TextContent': '', 'ToUserName': to_wxid}
"AtWxIDList": ats,
"ImageContent": img_url,
"MsgType": 0,
"TextContent": '',
"ToUserName": to_wxid
}
] ]
} }
url = self.base_url + "/message/SendImageMessage" url = self.base_url + '/message/SendImageMessage'
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration): def post_voice(self, to_wxid, voice_data, voice_forma, voice_duration):
"""发送语音消息""" """发送语音消息"""
json_data = { json_data = {
"ToUserName": to_wxid, 'ToUserName': to_wxid,
"VoiceData": voice_data, 'VoiceData': voice_data,
"VoiceFormat": voice_forma, 'VoiceFormat': voice_forma,
"VoiceSecond": voice_duration 'VoiceSecond': voice_duration,
} }
url = self.base_url + "/message/SendVoice" url = self.base_url + '/message/SendVoice'
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag): def post_name_card(self, alias, to_wxid, nick_name, name_card_wxid, flag):
"""发送名片消息""" """发送名片消息"""
param = { param = {
"CardAlias": alias, 'CardAlias': alias,
"CardFlag": flag, 'CardFlag': flag,
"CardNickName": nick_name, 'CardNickName': nick_name,
"CardWxId": name_card_wxid, 'CardWxId': name_card_wxid,
"ToUserName": to_wxid 'ToUserName': to_wxid,
} }
url = f"{self.base_url}/message/ShareCardMessage" url = f'{self.base_url}/message/ShareCardMessage'
return post_json(base_url=url, token=self.token, data=param) return post_json(base_url=url, token=self.token, data=param)
def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0): def post_emoji(self, to_wxid, emoji_md5, emoji_size: int = 0):
"""发送emoji消息""" """发送emoji消息"""
json_data = { json_data = {'EmojiList': [{'EmojiMd5': emoji_md5, 'EmojiSize': emoji_size, 'ToUserName': to_wxid}]}
"EmojiList": [ url = f'{self.base_url}/message/SendEmojiMessage'
{
"EmojiMd5": emoji_md5,
"EmojiSize": emoji_size,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendEmojiMessage"
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0): def post_app_msg(self, to_wxid, xml_data, contenttype: int = 0):
"""发送appmsg消息""" """发送appmsg消息"""
json_data = { json_data = {'AppList': [{'ContentType': contenttype, 'ContentXML': xml_data, 'ToUserName': to_wxid}]}
"AppList": [ url = f'{self.base_url}/message/SendAppMessage'
{
"ContentType": contenttype,
"ContentXML": xml_data,
"ToUserName": to_wxid
}
]
}
url = f"{self.base_url}/message/SendAppMessage"
return post_json(base_url=url, token=self.token, data=json_data) return post_json(base_url=url, token=self.token, data=json_data)
def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time): def revoke_msg(self, to_wxid, msg_id, new_msg_id, create_time):
"""撤回消息""" """撤回消息"""
param = { param = {'ClientMsgId': msg_id, 'CreateTime': create_time, 'NewMsgId': new_msg_id, 'ToUserName': to_wxid}
"ClientMsgId": msg_id, url = f'{self.base_url}/message/RevokeMsg'
"CreateTime": create_time,
"NewMsgId": new_msg_id,
"ToUserName": to_wxid
}
url = f"{self.base_url}/message/RevokeMsg"
return post_json(base_url=url, token=self.token, data=param) return post_json(base_url=url, token=self.token, data=param)

View File

@@ -1,10 +1,9 @@
import requests import requests
import aiohttp
def post_json(base_url, token, data=None): def post_json(base_url, token, data=None):
headers = { headers = {'Content-Type': 'application/json'}
'Content-Type': 'application/json'
}
url = base_url + f'?key={token}' url = base_url + f'?key={token}'
@@ -18,14 +17,12 @@ def post_json(base_url, token, data=None):
else: else:
raise RuntimeError(response.text) raise RuntimeError(response.text)
except Exception as e: except Exception as e:
print(f"http请求失败, url={url}, exception={e}") print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e)) raise RuntimeError(str(e))
def get_json(base_url, token):
headers = {
'Content-Type': 'application/json'
}
def get_json(base_url, token):
headers = {'Content-Type': 'application/json'}
url = base_url + f'?key={token}' url = base_url + f'?key={token}'
@@ -39,12 +36,9 @@ def get_json(base_url, token):
else: else:
raise RuntimeError(response.text) raise RuntimeError(response.text)
except Exception as e: except Exception as e:
print(f"http请求失败, url={url}, exception={e}") print(f'http请求失败, url={url}, exception={e}')
raise RuntimeError(str(e)) raise RuntimeError(str(e))
import aiohttp
import asyncio
async def async_request( async def async_request(
base_url: str, base_url: str,
@@ -53,7 +47,7 @@ async def async_request(
params: dict = None, params: dict = None,
# headers: dict = None, # headers: dict = None,
data: dict = None, data: dict = None,
json: dict = None json: dict = None,
): ):
""" """
通用异步请求函数 通用异步请求函数
@@ -67,18 +61,11 @@ async def async_request(
:param json: JSON数据 :param json: JSON数据
:return: 响应文本 :return: 响应文本
""" """
headers = { headers = {'Content-Type': 'application/json'}
'Content-Type': 'application/json' url = f'{base_url}?key={token_key}'
}
url = f"{base_url}?key={token_key}"
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.request( async with session.request(
method=method, method=method, url=url, params=params, headers=headers, data=data, json=json
url=url,
params=params,
headers=headers,
data=data,
json=json
) as response: ) as response:
response.raise_for_status() # 如果状态码不是200抛出异常 response.raise_for_status() # 如果状态码不是200抛出异常
result = await response.json() result = await response.json()
@@ -89,4 +76,3 @@ async def async_request(
# return await result # return await result
# else: # else:
# raise RuntimeError("请求失败",response.text) # raise RuntimeError("请求失败",response.text)

View File

@@ -14,8 +14,9 @@ class WebChatDebugRouterGroup(group.RouterGroup):
async def stream_generator(generator): async def stream_generator(generator):
async for message in generator: async for message in generator:
yield f"data: {json.dumps({'message': message})}\n\n" yield f'data: {json.dumps({"message": message})}\n\n'
yield "data: {\"type\": \"end\"}\n\n" yield 'data: {"type": "end"}\n\n'
try: try:
data = await quart.request.get_json() data = await quart.request.get_json()
session_type = data.get('session_type', 'person') session_type = data.get('session_type', 'person')
@@ -34,18 +35,18 @@ class WebChatDebugRouterGroup(group.RouterGroup):
return self.http_status(404, -1, 'WebChat adapter not found') return self.http_status(404, -1, 'WebChat adapter not found')
if is_stream: if is_stream:
generator = webchat_adapter.send_webchat_message(
generator = webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj, is_stream) pipeline_uuid, session_type, message_chain_obj, is_stream
return quart.Response(
stream_generator(generator),
mimetype='text/event-stream'
) )
return quart.Response(stream_generator(generator), mimetype='text/event-stream')
else: else:
# result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj) # result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj)
result = None result = None
async for message in webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj): async for message in webchat_adapter.send_webchat_message(
pipeline_uuid, session_type, message_chain_obj
):
result = message result = message
if result is not None: if result is not None:
return self.success( return self.success(
@@ -56,7 +57,6 @@ class WebChatDebugRouterGroup(group.RouterGroup):
else: else:
return self.http_status(400, -1, 'message is required') return self.http_status(400, -1, 'message is required')
except Exception as e: except Exception as e:
return self.http_status(500, -1, f'Internal server error: {str(e)}') return self.http_status(500, -1, f'Internal server error: {str(e)}')

View File

@@ -87,7 +87,9 @@ class Query(pydantic.BaseModel):
"""使用的函数,由前置处理器阶段设置""" """使用的函数,由前置处理器阶段设置"""
resp_messages: ( resp_messages: (
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]] | typing.Optional[list[llm_entities.MessageChunk]] typing.Optional[list[llm_entities.Message]]
| typing.Optional[list[platform_message.MessageChain]]
| typing.Optional[list[llm_entities.MessageChunk]]
) = [] ) = []
"""由Process阶段生成的回复消息对象列表""" """由Process阶段生成的回复消息对象列表"""

View File

@@ -1,7 +1,6 @@
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from itertools import accumulate
import typing import typing
import traceback import traceback
@@ -82,9 +81,7 @@ class ChatMessageHandler(handler.MessageHandler):
query.resp_message_chain.pop() query.resp_message_chain.pop()
query.resp_messages.append(result) query.resp_messages.append(result)
self.ap.logger.info( self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}')
f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}'
)
if result.content is not None: if result.content is not None:
text_length += len(result.content) text_length += len(result.content)

View File

@@ -3,12 +3,10 @@ from __future__ import annotations
import random import random
import asyncio import asyncio
from typing_inspection.typing_objects import is_final
from ...platform.types import events as platform_events from ...platform.types import events as platform_events
from ...platform.types import message as platform_message from ...platform.types import message as platform_message
from ...provider import entities as llm_entities
from .. import stage, entities from .. import stage, entities
from ...core import entities as core_entities from ...core import entities as core_entities
@@ -56,6 +54,4 @@ class SendResponseBackStage(stage.PipelineStage):
quote_origin=quote_origin, quote_origin=quote_origin,
) )
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query) return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)

View File

@@ -25,7 +25,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
logger: EventLogger logger: EventLogger
def __init__(self, config: dict, ap: app.Application, logger: EventLogger): def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
"""初始化适配器 """初始化适配器
@@ -94,8 +93,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_message.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
): ):
"""注册事件监听器 """注册事件监听器
@@ -107,8 +106,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_message.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None], callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
): ):
"""注销事件监听器 """注销事件监听器
@@ -167,7 +166,7 @@ class EventConverter:
"""事件转换器基类""" """事件转换器基类"""
@staticmethod @staticmethod
def yiri2target(event: typing.Type[platform_message.Event]): def yiri2target(event: typing.Type[platform_events.Event]):
"""将源平台事件转换为目标平台事件 """将源平台事件转换为目标平台事件
Args: Args:
@@ -179,7 +178,7 @@ class EventConverter:
raise NotImplementedError raise NotImplementedError
@staticmethod @staticmethod
def target2yiri(event: typing.Any) -> platform_message.Event: def target2yiri(event: typing.Any) -> platform_events.Event:
"""将目标平台事件的调用参数转换为源平台的事件参数对象 """将目标平台事件的调用参数转换为源平台的事件参数对象
Args: Args:

View File

@@ -16,7 +16,6 @@ from ..logger import EventLogger
class AiocqhttpMessageConverter(adapter.MessageConverter): class AiocqhttpMessageConverter(adapter.MessageConverter):
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(
message_chain: platform_message.MessageChain, message_chain: platform_message.MessageChain,
@@ -62,7 +61,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
for node in msg.node_list: for node in msg.node_list:
msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0]) msg_list.extend((await AiocqhttpMessageConverter.yiri2target(node.message_chain))[0])
elif isinstance(msg, platform_message.File): elif isinstance(msg, platform_message.File):
msg_list.append({"type":"file", "data":{'file': msg.url, "name": msg.name}}) msg_list.append({'type': 'file', 'data': {'file': msg.url, 'name': msg.name}})
elif isinstance(msg, platform_message.Face): elif isinstance(msg, platform_message.Face):
if msg.face_type == 'face': if msg.face_type == 'face':
msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id)) msg_list.append(aiocqhttp.MessageSegment.face(msg.face_id))
@@ -71,7 +70,6 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.face_type == 'dice': elif msg.face_type == 'dice':
msg_list.append(aiocqhttp.MessageSegment.dice()) msg_list.append(aiocqhttp.MessageSegment.dice())
else: else:
msg_list.append(aiocqhttp.MessageSegment.text(str(msg))) msg_list.append(aiocqhttp.MessageSegment.text(str(msg)))
@@ -84,65 +82,149 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
def get_face_name(face_id): def get_face_name(face_id):
face_code_dict = { face_code_dict = {
"2": '好色', '2': '好色',
"4": "得意", "5": "流泪", "8": "", "9": "大哭", "10": "尴尬", "12": "调皮", "14": "微笑", "16": "", '4': '得意',
"21": "可爱", '5': '流泪',
"23": "傲慢", "24": "饥饿", "25": "", "26": "惊恐", "27": "流汗", "28": "憨笑", "29": "悠闲", '8': '',
"30": "奋斗", '9': '大哭',
"32": "疑问", "33": "", "34": "", "38": "敲打", "39": "再见", "41": "发抖", "42": "爱情", '10': '尴尬',
"43": "跳跳", '12': '调皮',
"49": "拥抱", "53": "蛋糕", "60": "咖啡", "63": "玫瑰", "66": "爱心", "74": "太阳", "75": "月亮", '14': '微笑',
"76": "", '16': '',
"78": "握手", "79": "胜利", "85": "飞吻", "89": "西瓜", "96": "冷汗", "97": "擦汗", "98": "抠鼻", '21': '可爱',
"99": "鼓掌", '23': '傲慢',
"100": "糗大了", "101": "坏笑", "102": "左哼哼", "103": "右哼哼", "104": "哈欠", "106": "委屈", '24': '饥饿',
"109": "左亲亲", '25': '',
"111": "可怜", "116": "示爱", "118": "抱拳", "120": "拳头", "122": "爱你", "123": "NO", "124": "OK", '26': '惊恐',
"125": "转圈", '27': '流汗',
"129": "挥手", "144": "喝彩", "147": "棒棒糖", "171": "", "173": "泪奔", "174": "无奈", "175": "卖萌", '28': '憨笑',
"176": "小纠结", "179": "doge", "180": "惊喜", "181": "骚扰", "182": "笑哭", "183": "我最美", '29': '悠闲',
"201": "点赞", '30': '奋斗',
"203": "托脸", "212": "托腮", "214": "啵啵", "219": "蹭一蹭", "222": "抱抱", "227": "拍手", '32': '疑问',
"232": "佛系", '33': '',
"240": "喷脸", "243": "甩头", "246": "加油抱抱", "262": "脑阔疼", "264": "捂脸", "265": "辣眼睛", '34': '',
"266": "哦哟", '38': '敲打',
"267": "头秃", "268": "问号脸", "269": "暗中观察", "270": "emm", "271": "吃瓜", "272": "呵呵哒", '39': '再见',
"273": "我酸了", '41': '发抖',
"277": "汪汪", "278": "", "281": "无眼笑", "282": "敬礼", "284": "面无表情", "285": "摸鱼", '42': '爱情',
"287": "", '43': '跳跳',
"289": "睁眼", "290": "敲开心", "293": "摸锦鲤", "294": "期待", "297": "拜谢", "298": "元宝", '49': '拥抱',
"299": "牛啊", '53': '蛋糕',
"305": "右亲亲", "306": "牛气冲天", "307": "喵喵", "314": "仔细分析", "315": "加油", "318": "崇拜", '60': '咖啡',
"319": "比心", '63': '玫瑰',
"320": "庆祝", "322": "拒绝", "324": "吃糖", "326": "生气" '66': '爱心',
'74': '太阳',
'75': '月亮',
'76': '',
'78': '握手',
'79': '胜利',
'85': '飞吻',
'89': '西瓜',
'96': '冷汗',
'97': '擦汗',
'98': '抠鼻',
'99': '鼓掌',
'100': '糗大了',
'101': '坏笑',
'102': '左哼哼',
'103': '右哼哼',
'104': '哈欠',
'106': '委屈',
'109': '左亲亲',
'111': '可怜',
'116': '示爱',
'118': '抱拳',
'120': '拳头',
'122': '爱你',
'123': 'NO',
'124': 'OK',
'125': '转圈',
'129': '挥手',
'144': '喝彩',
'147': '棒棒糖',
'171': '',
'173': '泪奔',
'174': '无奈',
'175': '卖萌',
'176': '小纠结',
'179': 'doge',
'180': '惊喜',
'181': '骚扰',
'182': '笑哭',
'183': '我最美',
'201': '点赞',
'203': '托脸',
'212': '托腮',
'214': '啵啵',
'219': '蹭一蹭',
'222': '抱抱',
'227': '拍手',
'232': '佛系',
'240': '喷脸',
'243': '甩头',
'246': '加油抱抱',
'262': '脑阔疼',
'264': '捂脸',
'265': '辣眼睛',
'266': '哦哟',
'267': '头秃',
'268': '问号脸',
'269': '暗中观察',
'270': 'emm',
'271': '吃瓜',
'272': '呵呵哒',
'273': '我酸了',
'277': '汪汪',
'278': '',
'281': '无眼笑',
'282': '敬礼',
'284': '面无表情',
'285': '摸鱼',
'287': '',
'289': '睁眼',
'290': '敲开心',
'293': '摸锦鲤',
'294': '期待',
'297': '拜谢',
'298': '元宝',
'299': '牛啊',
'305': '右亲亲',
'306': '牛气冲天',
'307': '喵喵',
'314': '仔细分析',
'315': '加油',
'318': '崇拜',
'319': '比心',
'320': '庆祝',
'322': '拒绝',
'324': '吃糖',
'326': '生气',
} }
return face_code_dict.get(face_id, '') return face_code_dict.get(face_id, '')
async def process_message_data(msg_data, reply_list): async def process_message_data(msg_data, reply_list):
if msg_data["type"] == "image": if msg_data['type'] == 'image':
image_base64, image_format = await image.qq_image_url_to_base64(msg_data["data"]['url']) image_base64, image_format = await image.qq_image_url_to_base64(msg_data['data']['url'])
reply_list.append( reply_list.append(platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
platform_message.Image(base64=f'data:image/{image_format};base64,{image_base64}'))
elif msg_data["type"] == "text": elif msg_data['type'] == 'text':
reply_list.append(platform_message.Plain(text=msg_data["data"]["text"])) reply_list.append(platform_message.Plain(text=msg_data['data']['text']))
elif msg_data["type"] == "forward": # 这里来应该传入转发消息组暂时传入qoute elif msg_data['type'] == 'forward': # 这里来应该传入转发消息组暂时传入qoute
for forward_msg_datas in msg_data["data"]["content"]: for forward_msg_datas in msg_data['data']['content']:
for forward_msg_data in forward_msg_datas["message"]: for forward_msg_data in forward_msg_datas['message']:
await process_message_data(forward_msg_data, reply_list) await process_message_data(forward_msg_data, reply_list)
elif msg_data["type"] == "at": elif msg_data['type'] == 'at':
if msg_data["data"]['qq'] == 'all': if msg_data['data']['qq'] == 'all':
reply_list.append(platform_message.AtAll()) reply_list.append(platform_message.AtAll())
else: else:
reply_list.append( reply_list.append(
platform_message.At( platform_message.At(
target=msg_data["data"]['qq'], target=msg_data['data']['qq'],
) )
) )
yiri_msg_list = [] yiri_msg_list = []
yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now())) yiri_msg_list.append(platform_message.Source(id=message_id, time=datetime.datetime.now()))
@@ -161,10 +243,10 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
elif msg.type == 'text': elif msg.type == 'text':
yiri_msg_list.append(platform_message.Plain(text=msg.data['text'])) yiri_msg_list.append(platform_message.Plain(text=msg.data['text']))
elif msg.type == 'image': elif msg.type == 'image':
emoji_id = msg.data.get("emoji_package_id", None) emoji_id = msg.data.get('emoji_package_id', None)
if emoji_id: if emoji_id:
face_id = emoji_id face_id = emoji_id
face_name = msg.data.get("summary", '') face_name = msg.data.get('summary', '')
image_msg = platform_message.Face(face_id=face_id, face_name=face_name) image_msg = platform_message.Face(face_id=face_id, face_name=face_name)
else: else:
image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url']) image_base64, image_format = await image.qq_image_url_to_base64(msg.data['url'])
@@ -178,14 +260,15 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
# await process_message_data(msg_data, yiri_msg_list) # await process_message_data(msg_data, yiri_msg_list)
pass pass
elif msg.type == 'reply': # 此处处理引用消息传入Qoute elif msg.type == 'reply': # 此处处理引用消息传入Qoute
msg_datas = await bot.get_msg(message_id=msg.data["id"]) msg_datas = await bot.get_msg(message_id=msg.data['id'])
for msg_data in msg_datas["message"]: for msg_data in msg_datas['message']:
await process_message_data(msg_data, reply_list) await process_message_data(msg_data, reply_list)
reply_msg = platform_message.Quote(message_id=msg.data["id"],sender_id=msg_datas["user_id"],origin=reply_list) reply_msg = platform_message.Quote(
message_id=msg.data['id'], sender_id=msg_datas['user_id'], origin=reply_list
)
yiri_msg_list.append(reply_msg) yiri_msg_list.append(reply_msg)
elif msg.type == 'file': elif msg.type == 'file':
@@ -193,7 +276,7 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
file_id = msg.data['file_id'] file_id = msg.data['file_id']
file_data = await bot.get_file(file_id=file_id) file_data = await bot.get_file(file_id=file_id)
file_name = file_data.get('file_name') file_name = file_data.get('file_name')
file_path = file_data.get('file') # file_path = file_data.get('file')
file_url = file_data.get('file_url') file_url = file_data.get('file_url')
file_size = file_data.get('file_size') file_size = file_data.get('file_size')
yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size)) yiri_msg_list.append(platform_message.File(id=file_id, name=file_name, url=file_url, size=file_size))
@@ -205,28 +288,16 @@ class AiocqhttpMessageConverter(adapter.MessageConverter):
yiri_msg_list.append(platform_message.Face(face_id=int(face_id), face_name=face_name.replace('/', ''))) yiri_msg_list.append(platform_message.Face(face_id=int(face_id), face_name=face_name.replace('/', '')))
elif msg.type == 'rps': elif msg.type == 'rps':
face_id = msg.data['result'] face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type="rps",face_id=int(face_id),face_name='猜拳')) yiri_msg_list.append(platform_message.Face(face_type='rps', face_id=int(face_id), face_name='猜拳'))
elif msg.type == 'dice': elif msg.type == 'dice':
face_id = msg.data['result'] face_id = msg.data['result']
yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子')) yiri_msg_list.append(platform_message.Face(face_type='dice', face_id=int(face_id), face_name='骰子'))
chain = platform_message.MessageChain(yiri_msg_list) chain = platform_message.MessageChain(yiri_msg_list)
return chain return chain
class AiocqhttpEventConverter(adapter.EventConverter): class AiocqhttpEventConverter(adapter.EventConverter):
@staticmethod @staticmethod
async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int): async def yiri2target(event: platform_events.MessageEvent, bot_account_id: int):
@@ -236,8 +307,6 @@ class AiocqhttpEventConverter(adapter.EventConverter):
async def target2yiri(event: aiocqhttp.Event, bot=None): async def target2yiri(event: aiocqhttp.Event, bot=None):
yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot) yiri_chain = await AiocqhttpMessageConverter.target2yiri(event.message, event.message_id, bot)
if event.message_type == 'group': if event.message_type == 'group':
permission = 'MEMBER' permission = 'MEMBER'
@@ -316,7 +385,6 @@ class AiocqhttpAdapter(adapter.MessagePlatformAdapter):
aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0] aiocq_msg = (await AiocqhttpMessageConverter.yiri2target(message))[0]
if target_type == 'group': if target_type == 'group':
await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg) await self.bot.send_group_msg(group_id=int(target_id), message=aiocq_msg)
elif target_type == 'person': elif target_type == 'person':
await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg) await self.bot.send_private_msg(user_id=int(target_id), message=aiocq_msg)

View File

@@ -149,10 +149,10 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
quote_origin: bool = False, quote_origin: bool = False,
is_final: bool = False, is_final: bool = False,
): ):
event = await DingTalkEventConverter.yiri2target( # event = await DingTalkEventConverter.yiri2target(
message_source, # message_source,
) # )
incoming_message = event.incoming_message # incoming_message = event.incoming_message
# msg_id = incoming_message.message_id # msg_id = incoming_message.message_id
@@ -205,7 +205,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
self.bot.on_message('GroupMessage')(on_message) self.bot.on_message('GroupMessage')(on_message)
async def run_async(self): async def run_async(self):
await self.bot.start() await self.bot.start()
async def kill(self) -> bool: async def kill(self) -> bool:

View File

@@ -8,7 +8,6 @@ import base64
import uuid import uuid
import os import os
import datetime import datetime
import io
import aiohttp import aiohttp
@@ -101,12 +100,13 @@ class DiscordMessageConverter(adapter.MessageConverter):
filename = f'{uuid.uuid4()}.webp' filename = f'{uuid.uuid4()}.webp'
# 默认保持PNG # 默认保持PNG
except Exception as e: except Exception as e:
print(f"Error reading image file {clean_path}: {e}") print(f'Error reading image file {clean_path}: {e}')
continue # 跳过读取失败的文件 continue # 跳过读取失败的文件
if image_bytes: if image_bytes:
# 使用BytesIO创建文件对象避免路径问题 # 使用BytesIO创建文件对象避免路径问题
import io import io
image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename)) image_files.append(discord.File(fp=io.BytesIO(image_bytes), filename=filename))
elif isinstance(ele, platform_message.Plain): elif isinstance(ele, platform_message.Plain):
text_string += ele.text text_string += ele.text
@@ -279,7 +279,7 @@ class DiscordAdapter(adapter.MessagePlatformAdapter):
await channel.send(**args) await channel.send(**args)
except Exception as e: except Exception as e:
await self.logger.error(f"Discord send_message failed: {e}") await self.logger.error(f'Discord send_message failed: {e}')
raise e raise e
async def reply_message( async def reply_message(

View File

@@ -9,7 +9,6 @@ import re
import base64 import base64
import uuid import uuid
import json import json
import time
import datetime import datetime
import hashlib import hashlib
from Crypto.Cipher import AES from Crypto.Cipher import AES
@@ -394,14 +393,14 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
if 'im.message.receive_v1' == type: if 'im.message.receive_v1' == type:
try: try:
event = await self.event_converter.target2yiri(p2v1, self.api_client) event = await self.event_converter.target2yiri(p2v1, self.api_client)
except Exception as e: except Exception:
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
if event.__class__ in self.listeners: if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self) await self.listeners[event.__class__](event, self)
return {'code': 200, 'message': 'ok'} return {'code': 200, 'message': 'ok'}
except Exception as e: except Exception:
await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
return {'code': 500, 'message': 'error'} return {'code': 500, 'message': 'error'}
@@ -559,10 +558,10 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
elif ele['tag'] == 'md': elif ele['tag'] == 'md':
text_message += ele['text'] text_message += ele['text']
content = { # content = {
'type': 'card_json', # 'type': 'card_json',
'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, # 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}},
} # }
request: ContentCardElementRequest = ( request: ContentCardElementRequest = (
ContentCardElementRequest.builder() ContentCardElementRequest.builder()

View File

@@ -72,8 +72,9 @@ class NakuruProjectMessageConverter(adapter_model.MessageConverter):
content=content_list, content=content_list,
) )
nakuru_forward_node_list.append(nakuru_forward_node) nakuru_forward_node_list.append(nakuru_forward_node)
except Exception as e: except Exception:
import traceback import traceback
traceback.print_exc() traceback.print_exc()
nakuru_msg_list.append(nakuru_forward_node_list) nakuru_msg_list.append(nakuru_forward_node_list)
@@ -276,7 +277,7 @@ class NakuruAdapter(adapter_model.MessagePlatformAdapter):
# 注册监听器 # 注册监听器
self.bot.receiver(source_cls.__name__)(listener_wrapper) self.bot.receiver(source_cls.__name__)(listener_wrapper)
except Exception as e: except Exception as e:
self.logger.error(f"Error in nakuru register_listener: {traceback.format_exc()}") self.logger.error(f'Error in nakuru register_listener: {traceback.format_exc()}')
raise e raise e
def unregister_listener( def unregister_listener(

View File

@@ -125,8 +125,8 @@ class OfficialAccountAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id self.bot_account_id = event.receiver_id
try: try:
return await callback(await self.event_converter.target2yiri(event), self) return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e: except Exception:
await self.logger.error(f"Error in officialaccount callback: {traceback.format_exc()}") await self.logger.error(f'Error in officialaccount callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage: if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message) self.bot.on_message('text')(on_message)

View File

@@ -501,7 +501,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
for event_handler in event_handler_mapping[event_type]: for event_handler in event_handler_mapping[event_type]:
setattr(self.bot, event_handler, wrapper) setattr(self.bot, event_handler, wrapper)
except Exception as e: except Exception as e:
self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}") self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}')
raise e raise e
def unregister_listener( def unregister_listener(

View File

@@ -154,10 +154,7 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员') raise ParamNotEnoughError('QQ官方机器人缺少相关配置项请查看文档或联系管理员')
self.bot = QQOfficialClient( self.bot = QQOfficialClient(
app_id=config['appid'], app_id=config['appid'], secret=config['secret'], token=config['token'], logger=self.logger
secret=config['secret'],
token=config['token'],
logger=self.logger
) )
async def reply_message( async def reply_message(
@@ -224,8 +221,8 @@ class QQOfficialAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = 'justbot' self.bot_account_id = 'justbot'
try: try:
return await callback(await self.event_converter.target2yiri(event), self) return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e: except Exception:
await self.logger.error(f"Error in qqofficial callback: {traceback.format_exc()}") await self.logger.error(f'Error in qqofficial callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage: if event_type == platform_events.FriendMessage:
self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message) self.bot.on_message('DIRECT_MESSAGE_CREATE')(on_message)

View File

@@ -104,7 +104,9 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
if missing_keys: if missing_keys:
raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员') raise ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
self.bot = SlackClient(bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger) self.bot = SlackClient(
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
)
async def reply_message( async def reply_message(
self, self,
@@ -139,8 +141,8 @@ class SlackAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = 'SlackBot' self.bot_account_id = 'SlackBot'
try: try:
return await callback(await self.event_converter.target2yiri(event, self.bot), self) return await callback(await self.event_converter.target2yiri(event, self.bot), self)
except Exception as e: except Exception:
await self.logger.error(f"Error in slack callback: {traceback.format_exc()}") await self.logger.error(f'Error in slack callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage: if event_type == platform_events.FriendMessage:
self.bot.on_message('im')(on_message) self.bot.on_message('im')(on_message)

View File

@@ -1,6 +1,5 @@
from __future__ import annotations from __future__ import annotations
import time
import telegram import telegram
import telegram.ext import telegram.ext
@@ -166,7 +165,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id) lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
await self.listeners[type(lb_event)](lb_event, self) await self.listeners[type(lb_event)](lb_event, self)
await self.is_stream_output_supported() await self.is_stream_output_supported()
except Exception as e: except Exception:
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
self.application = ApplicationBuilder().token(self.config['token']).build() self.application = ApplicationBuilder().token(self.config['token']).build()

View File

@@ -133,7 +133,11 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
) )
# notify waiter # notify waiter
session = (self.webchat_group_session if isinstance(message_source, platform_events.GroupMessage) else self.webchat_person_session) session = (
self.webchat_group_session
if isinstance(message_source, platform_events.GroupMessage)
else self.webchat_person_session
)
if message_source.message_chain.message_id not in session.resp_waiters: if message_source.message_chain.message_id not in session.resp_waiters:
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue() # session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
queue = session.resp_queues[message_source.message_chain.message_id] queue = session.resp_queues[message_source.message_chain.message_id]
@@ -147,8 +151,6 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
# print(message_data) # print(message_data)
await queue.put(message_data) await queue.put(message_data)
return message_data.model_dump() return message_data.model_dump()
async def is_stream_output_supported(self) -> bool: async def is_stream_output_supported(self) -> bool:
@@ -186,7 +188,10 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
await self.logger.info('WebChat调试适配器正在停止') await self.logger.info('WebChat调试适配器正在停止')
async def send_webchat_message( async def send_webchat_message(
self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict], self,
pipeline_uuid: str,
session_type: str,
message_chain_obj: typing.List[dict],
is_stream: bool = False, is_stream: bool = False,
) -> dict: ) -> dict:
self.is_stream = is_stream self.is_stream = is_stream
@@ -202,7 +207,7 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
if is_stream: if is_stream:
use_session.resp_queues[message_id] = asyncio.Queue() use_session.resp_queues[message_id] = asyncio.Queue()
logger.debug(f"Initialized queue for message_id: {message_id}") logger.debug(f'Initialized queue for message_id: {message_id}')
use_session.get_message_list(pipeline_uuid).append( use_session.get_message_list(pipeline_uuid).append(
WebChatMessage( WebChatMessage(

View File

@@ -1,5 +1,4 @@
import requests import requests
import websockets
import websocket import websocket
import json import json
import time import time
@@ -10,53 +9,41 @@ from libs.wechatpad_api.client import WeChatPadClient
import typing import typing
import asyncio import asyncio
import traceback import traceback
import time
import re import re
import base64 import base64
import uuid
import json
import os
import copy import copy
import datetime
import threading import threading
import quart import quart
import aiohttp
from .. import adapter from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app from ...core import app
from ..types import message as platform_message from ..types import message as platform_message
from ..types import events as platform_events from ..types import events as platform_events
from ..types import entities as platform_entities from ..types import entities as platform_entities
from ...utils import image
from ..logger import EventLogger from ..logger import EventLogger
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from typing import Optional, List, Tuple from typing import Optional, Tuple
from functools import partial from functools import partial
import logging import logging
class WeChatPadMessageConverter(adapter.MessageConverter):
class WeChatPadMessageConverter(adapter.MessageConverter):
def __init__(self, config: dict): def __init__(self, config: dict):
self.config = config self.config = config
self.bot = WeChatPadClient(self.config["wechatpad_url"],self.config["token"]) self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
self.logger = logging.getLogger("WeChatPadMessageConverter") self.logger = logging.getLogger('WeChatPadMessageConverter')
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(message_chain: platform_message.MessageChain) -> list[dict]:
message_chain: platform_message.MessageChain
) -> list[dict]:
content_list = [] content_list = []
current_file_path = os.path.abspath(__file__) # current_file_path = os.path.abspath(__file__)
for component in message_chain: for component in message_chain:
if isinstance(component, platform_message.At): if isinstance(component, platform_message.At):
content_list.append({"type": "at", "target": component.target}) content_list.append({'type': 'at', 'target': component.target})
elif isinstance(component, platform_message.Plain): elif isinstance(component, platform_message.Plain):
content_list.append({"type": "text", "content": component.text}) content_list.append({'type': 'text', 'content': component.text})
elif isinstance(component, platform_message.Image): elif isinstance(component, platform_message.Image):
if component.url: if component.url:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
@@ -68,15 +55,16 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
else: else:
raise Exception('获取文件失败') raise Exception('获取文件失败')
# pass # pass
content_list.append({"type": "image", "image": base64_str}) content_list.append({'type': 'image', 'image': base64_str})
elif component.base64: elif component.base64:
content_list.append({"type": "image", "image": component.base64}) content_list.append({'type': 'image', 'image': component.base64})
elif isinstance(component, platform_message.WeChatEmoji): elif isinstance(component, platform_message.WeChatEmoji):
content_list.append( content_list.append(
{'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}) {'type': 'WeChatEmoji', 'emoji_md5': component.emoji_md5, 'emoji_size': component.emoji_size}
)
elif isinstance(component, platform_message.Voice): elif isinstance(component, platform_message.Voice):
content_list.append({"type": "voice", "data": component.url, "duration": component.length, "forma": 0}) content_list.append({'type': 'voice', 'data': component.url, 'duration': component.length, 'forma': 0})
elif isinstance(component, platform_message.WeChatAppMsg): elif isinstance(component, platform_message.WeChatAppMsg):
content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg}) content_list.append({'type': 'WeChatAppMsg', 'app_msg': component.app_msg})
elif isinstance(component, platform_message.Forward): elif isinstance(component, platform_message.Forward):
@@ -86,28 +74,23 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return content_list return content_list
async def target2yiri(self, message: dict, bot_account_id: str) -> platform_message.MessageChain:
async def target2yiri(
self,
message: dict,
bot_account_id: str
) -> platform_message.MessageChain:
"""外部消息转平台消息""" """外部消息转平台消息"""
# 数据预处理 # 数据预处理
message_list = [] message_list = []
ats_bot = False # 是否被@ ats_bot = False # 是否被@
content = message["content"]["str"] content = message['content']['str']
content_no_preifx = content # 群消息则去掉前缀 content_no_preifx = content # 群消息则去掉前缀
is_group_message = self._is_group_message(message) is_group_message = self._is_group_message(message)
if is_group_message: if is_group_message:
ats_bot = self._ats_bot(message, bot_account_id) ats_bot = self._ats_bot(message, bot_account_id)
if "@所有人" in content: if '@所有人' in content:
message_list.append(platform_message.AtAll()) message_list.append(platform_message.AtAll())
elif ats_bot: elif ats_bot:
message_list.append(platform_message.At(target=bot_account_id)) message_list.append(platform_message.At(target=bot_account_id))
content_no_preifx, _ = self._extract_content_and_sender(content) content_no_preifx, _ = self._extract_content_and_sender(content)
msg_type = message["msg_type"] msg_type = message['msg_type']
# 映射消息类型到处理器方法 # 映射消息类型到处理器方法
handler_map = { handler_map = {
@@ -129,11 +112,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list) return platform_message.MessageChain(message_list)
async def _handler_text( async def _handler_text(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
"""处理文本消息 (msg_type=1)""" """处理文本消息 (msg_type=1)"""
if message and self._is_group_message(message): if message and self._is_group_message(message):
pattern = r'@\S{1,20}' pattern = r'@\S{1,20}'
@@ -141,16 +120,12 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain([platform_message.Plain(content_no_preifx)]) return platform_message.MessageChain([platform_message.Plain(content_no_preifx)])
async def _handler_image( async def _handler_image(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
"""处理图像消息 (msg_type=3)""" """处理图像消息 (msg_type=3)"""
try: try:
image_xml = content_no_preifx image_xml = content_no_preifx
if not image_xml: if not image_xml:
return platform_message.MessageChain([platform_message.Unknown("[图片内容为空]")]) return platform_message.MessageChain([platform_message.Unknown('[图片内容为空]')])
root = ET.fromstring(image_xml) root = ET.fromstring(image_xml)
# 提取img标签的属性 # 提取img标签的属性
@@ -160,28 +135,22 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
cdnthumburl = img_tag.get('cdnthumburl') cdnthumburl = img_tag.get('cdnthumburl')
# cdnmidimgurl = img_tag.get('cdnmidimgurl') # cdnmidimgurl = img_tag.get('cdnmidimgurl')
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl) image_data = self.bot.cdn_download(aeskey=aeskey, file_type=1, file_url=cdnthumburl)
if image_data["Data"]['FileData'] == '': if image_data['Data']['FileData'] == '':
image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl) image_data = self.bot.cdn_download(aeskey=aeskey, file_type=2, file_url=cdnthumburl)
base64_str = image_data["Data"]['FileData'] base64_str = image_data['Data']['FileData']
# self.logger.info(f"data:image/png;base64,{base64_str}") # self.logger.info(f"data:image/png;base64,{base64_str}")
elements = [ elements = [
platform_message.Image(base64=f"data:image/png;base64,{base64_str}"), platform_message.Image(base64=f'data:image/png;base64,{base64_str}'),
# platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发 # platform_message.WeChatForwardImage(xml_data=image_xml) # 微信消息转发
] ]
return platform_message.MessageChain(elements) return platform_message.MessageChain(elements)
except Exception as e: except Exception as e:
self.logger.error(f"处理图片失败: {str(e)}") self.logger.error(f'处理图片失败: {str(e)}')
return platform_message.MessageChain([platform_message.Unknown("[图片处理失败]")]) return platform_message.MessageChain([platform_message.Unknown('[图片处理失败]')])
async def _handler_voice( async def _handler_voice(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
"""处理语音消息 (msg_type=34)""" """处理语音消息 (msg_type=34)"""
message_List = [] message_List = []
try: try:
@@ -197,39 +166,33 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
bufid = voicemsg.get('bufid') bufid = voicemsg.get('bufid')
length = voicemsg.get('voicelength') length = voicemsg.get('voicelength')
voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id)) voice_data = self.bot.get_msg_voice(buf_id=str(bufid), length=int(length), msgid=str(new_msg_id))
audio_base64 = voice_data["Data"]['Base64'] audio_base64 = voice_data['Data']['Base64']
# 验证语音数据有效性 # 验证语音数据有效性
if not audio_base64: if not audio_base64:
message_List.append(platform_message.Unknown(text="[语音内容为空]")) message_List.append(platform_message.Unknown(text='[语音内容为空]'))
return platform_message.MessageChain(message_List) return platform_message.MessageChain(message_List)
# 转换为平台支持的语音格式(如 Silk 格式) # 转换为平台支持的语音格式(如 Silk 格式)
voice_element = platform_message.Voice( voice_element = platform_message.Voice(base64=f'data:audio/silk;base64,{audio_base64}')
base64=f"data:audio/silk;base64,{audio_base64}"
)
message_List.append(voice_element) message_List.append(voice_element)
except KeyError as e: except KeyError as e:
self.logger.error(f"语音数据字段缺失: {str(e)}") self.logger.error(f'语音数据字段缺失: {str(e)}')
message_List.append(platform_message.Unknown(text="[语音数据解析失败]")) message_List.append(platform_message.Unknown(text='[语音数据解析失败]'))
except Exception as e: except Exception as e:
self.logger.error(f"处理语音消息异常: {str(e)}") self.logger.error(f'处理语音消息异常: {str(e)}')
message_List.append(platform_message.Unknown(text="[语音处理失败]")) message_List.append(platform_message.Unknown(text='[语音处理失败]'))
return platform_message.MessageChain(message_List) return platform_message.MessageChain(message_List)
async def _handler_compound( async def _handler_compound(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
"""处理复合消息 (msg_type=49),根据子类型分派""" """处理复合消息 (msg_type=49),根据子类型分派"""
try: try:
xml_data = ET.fromstring(content_no_preifx) xml_data = ET.fromstring(content_no_preifx)
appmsg_data = xml_data.find('.//appmsg') appmsg_data = xml_data.find('.//appmsg')
if appmsg_data: if appmsg_data:
data_type = appmsg_data.findtext('.//type', "") data_type = appmsg_data.findtext('.//type', '')
# 二次分派处理器 # 二次分派处理器
sub_handler_map = { sub_handler_map = {
'57': self._handler_compound_quote, '57': self._handler_compound_quote,
@@ -238,9 +201,9 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
'74': self._handler_compound_file, '74': self._handler_compound_file,
'33': self._handler_compound_mini_program, '33': self._handler_compound_mini_program,
'36': self._handler_compound_mini_program, '36': self._handler_compound_mini_program,
'2000': partial(self._handler_compound_unsupported, text="[转账消息]"), '2000': partial(self._handler_compound_unsupported, text='[转账消息]'),
'2001': partial(self._handler_compound_unsupported, text="[红包消息]"), '2001': partial(self._handler_compound_unsupported, text='[红包消息]'),
'51': partial(self._handler_compound_unsupported, text="[视频号消息]"), '51': partial(self._handler_compound_unsupported, text='[视频号消息]'),
} }
handler = sub_handler_map.get(data_type, self._handler_compound_unsupported) handler = sub_handler_map.get(data_type, self._handler_compound_unsupported)
@@ -251,56 +214,51 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
else: else:
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
except Exception as e: except Exception as e:
self.logger.error(f"解析复合消息失败: {str(e)}") self.logger.error(f'解析复合消息失败: {str(e)}')
return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)]) return platform_message.MessageChain([platform_message.Unknown(text=content_no_preifx)])
async def _handler_compound_quote( async def _handler_compound_quote(
self, self, message: Optional[dict], xml_data: ET.Element
message: Optional[dict],
xml_data: ET.Element
) -> platform_message.MessageChain: ) -> platform_message.MessageChain:
"""处理引用消息 (data_type=57)""" """处理引用消息 (data_type=57)"""
message_list = [] message_list = []
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode')) # self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
appmsg_data = xml_data.find('.//appmsg') appmsg_data = xml_data.find('.//appmsg')
quote_data = "" # 引用原文 quote_data = '' # 引用原文
quote_id = None # 引用消息的原发送者 # quote_id = None # 引用消息的原发送者
tousername = None # 接收方: 所属微信的wxid # tousername = None # 接收方: 所属微信的wxid
user_data = "" # 用户消息 user_data = '' # 用户消息
sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member
# 引用消息转发 # 引用消息转发
if appmsg_data: if appmsg_data:
user_data = appmsg_data.findtext('.//title') or "" user_data = appmsg_data.findtext('.//title') or ''
quote_data = appmsg_data.find('.//refermsg').findtext('.//content') quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr')
message_list.append( message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode')))
platform_message.WeChatAppMsg( # if message:
app_msg=ET.tostring(appmsg_data, encoding='unicode')) # tousername = message['to_user_name']['str']
)
if message:
tousername = message['to_user_name']["str"]
if quote_data: if quote_data:
quote_data_message_list = platform_message.MessageChain() quote_data_message_list = platform_message.MessageChain()
# 文本消息 # 文本消息
try: try:
if "<msg>" not in quote_data: if '<msg>' not in quote_data:
quote_data_message_list.append(platform_message.Plain(quote_data)) quote_data_message_list.append(platform_message.Plain(quote_data))
else: else:
# 引用消息展开 # 引用消息展开
quote_data_xml = ET.fromstring(quote_data) quote_data_xml = ET.fromstring(quote_data)
if quote_data_xml.find("img"): if quote_data_xml.find('img'):
quote_data_message_list.extend(await self._handler_image(None, quote_data)) quote_data_message_list.extend(await self._handler_image(None, quote_data))
elif quote_data_xml.find("voicemsg"): elif quote_data_xml.find('voicemsg'):
quote_data_message_list.extend(await self._handler_voice(None, quote_data)) quote_data_message_list.extend(await self._handler_voice(None, quote_data))
elif quote_data_xml.find("videomsg"): elif quote_data_xml.find('videomsg'):
quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理 quote_data_message_list.extend(await self._handler_default(None, quote_data)) # 先不处理
else: else:
# appmsg # appmsg
quote_data_message_list.extend(await self._handler_compound(None, quote_data)) quote_data_message_list.extend(await self._handler_compound(None, quote_data))
except Exception as e: except Exception as e:
self.logger.error(f"处理引用消息异常 expcetion:{e}") self.logger.error(f'处理引用消息异常 expcetion:{e}')
quote_data_message_list.append(platform_message.Plain(quote_data)) quote_data_message_list.append(platform_message.Plain(quote_data))
message_list.append( message_list.append(
platform_message.Quote( platform_message.Quote(
@@ -315,15 +273,11 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
return platform_message.MessageChain(message_list) return platform_message.MessageChain(message_list)
async def _handler_compound_file( async def _handler_compound_file(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
self,
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain:
"""处理文件消息 (data_type=6)""" """处理文件消息 (data_type=6)"""
file_data = xml_data.find('.//appmsg') file_data = xml_data.find('.//appmsg')
if file_data.findtext('.//type', "") == "74": if file_data.findtext('.//type', '') == '74':
return None return None
else: else:
@@ -346,22 +300,21 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl) file_data = self.bot.cdn_download(aeskey=aeskey, file_type=5, file_url=cdnthumburl)
file_base64 = file_data["Data"]['FileData'] file_base64 = file_data['Data']['FileData']
# print(file_data) # print(file_data)
file_size = file_data["Data"]['TotalSize'] file_size = file_data['Data']['TotalSize']
# print(file_base64) # print(file_base64)
return platform_message.MessageChain([ return platform_message.MessageChain(
platform_message.WeChatFile(file_id=file_id, file_name=file_name, file_size=file_size, [
file_base64=file_base64), platform_message.WeChatFile(
platform_message.WeChatForwardFile(xml_data=xml_data_str) file_id=file_id, file_name=file_name, file_size=file_size, file_base64=file_base64
]) ),
platform_message.WeChatForwardFile(xml_data=xml_data_str),
]
)
async def _handler_compound_link( async def _handler_compound_link(self, message: dict, xml_data: ET.Element) -> platform_message.MessageChain:
self,
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain:
"""处理链接消息(如公众号文章、外部网页)""" """处理链接消息(如公众号文章、外部网页)"""
message_list = [] message_list = []
try: try:
@@ -374,56 +327,38 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
link_title=appmsg.findtext('title', ''), link_title=appmsg.findtext('title', ''),
link_desc=appmsg.findtext('des', ''), link_desc=appmsg.findtext('des', ''),
link_url=appmsg.findtext('url', ''), link_url=appmsg.findtext('url', ''),
link_thumb_url=appmsg.findtext("thumburl", '') # 这个字段拿不到 link_thumb_url=appmsg.findtext('thumburl', ''), # 这个字段拿不到
) )
) )
# 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。 # 还没有发链接的接口, 暂时还需要自己构造appmsg, 先用WeChatAppMsg。
message_list.append( message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg, encoding='unicode')))
platform_message.WeChatAppMsg(
app_msg=ET.tostring(appmsg, encoding='unicode')
)
)
except Exception as e: except Exception as e:
self.logger.error(f"解析链接消息失败: {str(e)}") self.logger.error(f'解析链接消息失败: {str(e)}')
return platform_message.MessageChain(message_list) return platform_message.MessageChain(message_list)
async def _handler_compound_mini_program( async def _handler_compound_mini_program(
self, self, message: dict, xml_data: ET.Element
message: dict,
xml_data: ET.Element
) -> platform_message.MessageChain: ) -> platform_message.MessageChain:
"""处理小程序消息(如小程序卡片、服务通知)""" """处理小程序消息(如小程序卡片、服务通知)"""
xml_data_str = ET.tostring(xml_data, encoding='unicode') xml_data_str = ET.tostring(xml_data, encoding='unicode')
return platform_message.MessageChain([ return platform_message.MessageChain([platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)])
platform_message.WeChatForwardMiniPrograms(xml_data=xml_data_str)
])
async def _handler_default( async def _handler_default(self, message: Optional[dict], content_no_preifx: str) -> platform_message.MessageChain:
self,
message: Optional[dict],
content_no_preifx: str
) -> platform_message.MessageChain:
"""处理未知消息类型""" """处理未知消息类型"""
if message: if message:
msg_type = message["msg_type"] msg_type = message['msg_type']
else: else:
msg_type = "" msg_type = ''
return platform_message.MessageChain([ return platform_message.MessageChain([platform_message.Unknown(text=f'[未知消息类型 msg_type:{msg_type}]')])
platform_message.Unknown(text=f"[未知消息类型 msg_type:{msg_type}]")
])
def _handler_compound_unsupported( def _handler_compound_unsupported(
self, self, message: dict, xml_data: str, text: Optional[str] = None
message: dict,
xml_data: str,
text: Optional[str] = None
) -> platform_message.MessageChain: ) -> platform_message.MessageChain:
"""处理未支持复合消息类型(msg_type=49)子类型""" """处理未支持复合消息类型(msg_type=49)子类型"""
if not text: if not text:
text = f"[xml_data={xml_data}]" text = f'[xml_data={xml_data}]'
content_list = [] content_list = []
content_list.append( content_list.append(platform_message.Unknown(text=f'[处理未支持复合消息类型[msg_type=49]|{text}'))
platform_message.Unknown(text=f"[处理未支持复合消息类型[msg_type=49]|{text}"))
return platform_message.MessageChain(content_list) return platform_message.MessageChain(content_list)
@@ -432,7 +367,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
ats_bot = False ats_bot = False
try: try:
to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid to_user_name = message['to_user_name']['str'] # 接收方: 所属微信的wxid
raw_content = message["content"]["str"] # 原始消息内容 raw_content = message['content']['str'] # 原始消息内容
content_no_prefix, _ = self._extract_content_and_sender(raw_content) content_no_prefix, _ = self._extract_content_and_sender(raw_content)
# 直接艾特机器人这个有bug当被引用的消息里面有@bot,会套娃 # 直接艾特机器人这个有bug当被引用的消息里面有@bot,会套娃
# ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix) # ats_bot = ats_bot or (f"@{bot_account_id}" in content_no_prefix)
@@ -443,7 +378,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
msg_source = message.get('msg_source', '') or '' msg_source = message.get('msg_source', '') or ''
if len(msg_source) > 0: if len(msg_source) > 0:
msg_source_data = ET.fromstring(msg_source) msg_source_data = ET.fromstring(msg_source)
at_user_list = msg_source_data.findtext("atuserlist") or "" at_user_list = msg_source_data.findtext('atuserlist') or ''
ats_bot = ats_bot or (to_user_name in at_user_list) ats_bot = ats_bot or (to_user_name in at_user_list)
# 引用bot # 引用bot
if message.get('msg_type', 0) == 49: if message.get('msg_type', 0) == 49:
@@ -454,7 +389,7 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者 quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr') # 引用消息的原发送者
ats_bot = ats_bot or (quote_id == tousername) ats_bot = ats_bot or (quote_id == tousername)
except Exception as e: except Exception as e:
self.logger.error(f"_ats_bot got except: {e}") self.logger.error(f'_ats_bot got except: {e}')
finally: finally:
return ats_bot return ats_bot
@@ -463,47 +398,41 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
try: try:
# 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉 # 检查消息开头,如果有 wxid_sbitaz0mt65n22:\n 则删掉
# add: 有些用户的wxid不是上述格式。换成user_name: # add: 有些用户的wxid不是上述格式。换成user_name:
regex = re.compile(r"^[a-zA-Z0-9_\-]{5,20}:") regex = re.compile(r'^[a-zA-Z0-9_\-]{5,20}:')
line_split = raw_content.split("\n") line_split = raw_content.split('\n')
if len(line_split) > 0 and regex.match(line_split[0]): if len(line_split) > 0 and regex.match(line_split[0]):
raw_content = "\n".join(line_split[1:]) raw_content = '\n'.join(line_split[1:])
sender_id = line_split[0].strip(":") sender_id = line_split[0].strip(':')
return raw_content, sender_id return raw_content, sender_id
except Exception as e: except Exception as e:
self.logger.error(f"_extract_content_and_sender got except: {e}") self.logger.error(f'_extract_content_and_sender got except: {e}')
finally: finally:
return raw_content, None return raw_content, None
# 是否是群消息 # 是否是群消息
def _is_group_message(self, message: dict) -> bool: def _is_group_message(self, message: dict) -> bool:
from_user_name = message['from_user_name']['str'] from_user_name = message['from_user_name']['str']
return from_user_name.endswith("@chatroom") return from_user_name.endswith('@chatroom')
class WeChatPadEventConverter(adapter.EventConverter): class WeChatPadEventConverter(adapter.EventConverter):
def __init__(self, config: dict): def __init__(self, config: dict):
self.config = config self.config = config
self.message_converter = WeChatPadMessageConverter(config) self.message_converter = WeChatPadMessageConverter(config)
self.logger = logging.getLogger("WeChatPadEventConverter") self.logger = logging.getLogger('WeChatPadEventConverter')
@staticmethod @staticmethod
async def yiri2target( async def yiri2target(event: platform_events.MessageEvent) -> dict:
event: platform_events.MessageEvent
) -> dict:
pass pass
async def target2yiri( async def target2yiri(self, event: dict, bot_account_id: str) -> platform_events.MessageEvent:
self,
event: dict,
bot_account_id: str
) -> platform_events.MessageEvent:
# 排除公众号以及微信团队消息 # 排除公众号以及微信团队消息
if event['from_user_name']['str'].startswith('gh_') \ if (
or event['from_user_name']['str']=='weixin'\ event['from_user_name']['str'].startswith('gh_')
or event['from_user_name']['str'] == "newsapp"\ or event['from_user_name']['str'] == 'weixin'
or event['from_user_name']['str'] == self.config["wxid"]: or event['from_user_name']['str'] == 'newsapp'
or event['from_user_name']['str'] == self.config['wxid']
):
return None return None
message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id) message_chain = await self.message_converter.target2yiri(copy.deepcopy(event), bot_account_id)
@@ -512,7 +441,7 @@ class WeChatPadEventConverter(adapter.EventConverter):
if '@chatroom' in event['from_user_name']['str']: if '@chatroom' in event['from_user_name']['str']:
# 找出开头的 wxid_ 字符串,以:结尾 # 找出开头的 wxid_ 字符串,以:结尾
sender_wxid = event['content']['str'].split(":")[0] sender_wxid = event['content']['str'].split(':')[0]
return platform_events.GroupMessage( return platform_events.GroupMessage(
sender=platform_entities.GroupMember( sender=platform_entities.GroupMember(
@@ -524,13 +453,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
name=event['from_user_name']['str'], name=event['from_user_name']['str'],
permission=platform_entities.Permission.Member, permission=platform_entities.Permission.Member,
), ),
special_title="", special_title='',
join_timestamp=0, join_timestamp=0,
last_speak_timestamp=0, last_speak_timestamp=0,
mute_time_remaining=0, mute_time_remaining=0,
), ),
message_chain=message_chain, message_chain=message_chain,
time=event["create_time"], time=event['create_time'],
source_platform_object=event, source_platform_object=event,
) )
else: else:
@@ -541,13 +470,13 @@ class WeChatPadEventConverter(adapter.EventConverter):
remark='', remark='',
), ),
message_chain=message_chain, message_chain=message_chain,
time=event["create_time"], time=event['create_time'],
source_platform_object=event, source_platform_object=event,
) )
class WeChatPadAdapter(adapter.MessagePlatformAdapter): class WeChatPadAdapter(adapter.MessagePlatformAdapter):
name: str = "WeChatPad" # 定义适配器名称 name: str = 'WeChatPad' # 定义适配器名称
bot: WeChatPadClient bot: WeChatPadClient
quart_app: quart.Quart quart_app: quart.Quart
@@ -580,27 +509,21 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# self.ap.logger.debug(f"Gewechat callback event: {data}") # self.ap.logger.debug(f"Gewechat callback event: {data}")
# print(data) # print(data)
try: try:
event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id) event = await self.event_converter.target2yiri(data.copy(), self.bot_account_id)
except Exception as e: except Exception:
await self.logger.error(f"Error in wechatpad callback: {traceback.format_exc()}") await self.logger.error(f'Error in wechatpad callback: {traceback.format_exc()}')
if event.__class__ in self.listeners: if event.__class__ in self.listeners:
await self.listeners[event.__class__](event, self) await self.listeners[event.__class__](event, self)
return 'ok' return 'ok'
async def _handle_message(self, message: platform_message.MessageChain, target_id: str):
async def _handle_message(
self,
message: platform_message.MessageChain,
target_id: str
):
"""统一消息处理核心逻辑""" """统一消息处理核心逻辑"""
content_list = await self.message_converter.yiri2target(message) content_list = await self.message_converter.yiri2target(message)
# print(content_list) # print(content_list)
at_targets = [item["target"] for item in content_list if item["type"] == "at"] at_targets = [item['target'] for item in content_list if item['type'] == 'at']
# print(at_targets) # print(at_targets)
# 处理@逻辑 # 处理@逻辑
at_targets = at_targets or [] at_targets = at_targets or []
@@ -608,7 +531,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if at_targets: if at_targets:
member_info = self.bot.get_chatroom_member_detail( member_info = self.bot.get_chatroom_member_detail(
target_id, target_id,
)["Data"]["member_data"]["chatroom_member_list"] )['Data']['member_data']['chatroom_member_list']
# 处理消息组件 # 处理消息组件
for msg in content_list: for msg in content_list:
@@ -616,55 +539,43 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
if msg['type'] == 'text' and at_targets: if msg['type'] == 'text' and at_targets:
at_nick_name_list = [] at_nick_name_list = []
for member in member_info: for member in member_info:
if member["user_name"] in at_targets: if member['user_name'] in at_targets:
at_nick_name_list.append(f'@{member["nick_name"]}') at_nick_name_list.append(f'@{member["nick_name"]}')
msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}' msg['content'] = f'{" ".join(at_nick_name_list)} {msg["content"]}'
# 统一消息派发 # 统一消息派发
handler_map = { handler_map = {
'text': lambda msg: self.bot.send_text_message( 'text': lambda msg: self.bot.send_text_message(
to_wxid=target_id, to_wxid=target_id, message=msg['content'], ats=at_targets
message=msg['content'],
ats=at_targets
), ),
'image': lambda msg: self.bot.send_image_message( 'image': lambda msg: self.bot.send_image_message(
to_wxid=target_id, to_wxid=target_id, img_url=msg['image'], ats=at_targets
img_url=msg["image"],
ats = at_targets
), ),
'WeChatEmoji': lambda msg: self.bot.send_emoji_message( 'WeChatEmoji': lambda msg: self.bot.send_emoji_message(
to_wxid=target_id, to_wxid=target_id, emoji_md5=msg['emoji_md5'], emoji_size=msg['emoji_size']
emoji_md5=msg['emoji_md5'],
emoji_size=msg['emoji_size']
), ),
'voice': lambda msg: self.bot.send_voice_message( 'voice': lambda msg: self.bot.send_voice_message(
to_wxid=target_id, to_wxid=target_id,
voice_data=msg['data'], voice_data=msg['data'],
voice_duration=msg["duration"], voice_duration=msg['duration'],
voice_forma=msg["forma"], voice_forma=msg['forma'],
), ),
'WeChatAppMsg': lambda msg: self.bot.send_app_message( 'WeChatAppMsg': lambda msg: self.bot.send_app_message(
to_wxid=target_id, to_wxid=target_id,
app_message=msg['app_msg'], app_message=msg['app_msg'],
type=0, type=0,
), ),
'at': lambda msg: None 'at': lambda msg: None,
} }
if handler := handler_map.get(msg['type']): if handler := handler_map.get(msg['type']):
handler(msg) handler(msg)
# self.ap.logger.warning(f"未处理的消息类型: {ret}") # self.ap.logger.warning(f"未处理的消息类型: {ret}")
else: else:
self.ap.logger.warning(f"未处理的消息类型: {msg['type']}") self.ap.logger.warning(f'未处理的消息类型: {msg["type"]}')
continue continue
async def send_message( async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
self,
target_type: str,
target_id: str,
message: platform_message.MessageChain
):
"""主动发送消息""" """主动发送消息"""
return await self._handle_message(message, target_id) return await self._handle_message(message, target_id)
@@ -672,7 +583,7 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
self, self,
message_source: platform_events.MessageEvent, message_source: platform_events.MessageEvent,
message: platform_message.MessageChain, message: platform_message.MessageChain,
quote_origin: bool = False quote_origin: bool = False,
): ):
"""回复消息""" """回复消息"""
if message_source.source_platform_object: if message_source.source_platform_object:
@@ -685,56 +596,47 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
def register_listener( def register_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
): ):
self.listeners[event_type] = callback self.listeners[event_type] = callback
def unregister_listener( def unregister_listener(
self, self,
event_type: typing.Type[platform_events.Event], event_type: typing.Type[platform_events.Event],
callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None] callback: typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
): ):
pass pass
async def run_async(self): async def run_async(self):
if not self.config['admin_key'] and not self.config['token']:
if not self.config["admin_key"] and not self.config["token"]: raise RuntimeError('无wechatpad管理密匙请填入配置文件后重启')
raise RuntimeError("无wechatpad管理密匙请填入配置文件后重启")
else: else:
if self.config["token"]: if self.config['token']:
self.bot = WeChatPadClient( self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'])
self.config['wechatpad_url'],
self.config["token"]
)
data = self.bot.get_login_status() data = self.bot.get_login_status()
self.ap.logger.info(data) self.ap.logger.info(data)
if data["Code"] == 300 and data["Text"] == "你已退出微信": if data['Code'] == 300 and data['Text'] == '你已退出微信':
response = requests.post( response = requests.post(
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
json={"Count": 1, "Days": 365} json={'Count': 1, 'Days': 365},
) )
if response.status_code != 200: if response.status_code != 200:
raise Exception(f"获取token失败: {response.text}") raise Exception(f'获取token失败: {response.text}')
self.config["token"] = response.json()["Data"][0] self.config['token'] = response.json()['Data'][0]
elif not self.config["token"]: elif not self.config['token']:
response = requests.post( response = requests.post(
f"{self.config['wechatpad_url']}/admin/GenAuthKey1?key={self.config['admin_key']}", f'{self.config["wechatpad_url"]}/admin/GenAuthKey1?key={self.config["admin_key"]}',
json={"Count": 1, "Days": 365} json={'Count': 1, 'Days': 365},
) )
if response.status_code != 200: if response.status_code != 200:
raise Exception(f"获取token失败: {response.text}") raise Exception(f'获取token失败: {response.text}')
self.config["token"] = response.json()["Data"][0] self.config['token'] = response.json()['Data'][0]
self.bot = WeChatPadClient( self.bot = WeChatPadClient(self.config['wechatpad_url'], self.config['token'], logger=self.logger)
self.config['wechatpad_url'], self.ap.logger.info(self.config['token'])
self.config["token"],
logger=self.logger
)
self.ap.logger.info(self.config["token"])
thread_1 = threading.Event() thread_1 = threading.Event()
def wechat_login_process(): def wechat_login_process():
# 不登录这些先注释掉避免登陆态尝试拉qrcode。 # 不登录这些先注释掉避免登陆态尝试拉qrcode。
# login_data =self.bot.get_login_qr() # login_data =self.bot.get_login_qr()
@@ -742,67 +644,54 @@ class WeChatPadAdapter(adapter.MessagePlatformAdapter):
# url = login_data['Data']["QrCodeUrl"] # url = login_data['Data']["QrCodeUrl"]
# self.ap.logger.info(login_data) # self.ap.logger.info(login_data)
profile = self.bot.get_profile() profile = self.bot.get_profile()
self.ap.logger.info(profile) self.ap.logger.info(profile)
self.bot_account_id = profile["Data"]["userInfo"]["nickName"]["str"] self.bot_account_id = profile['Data']['userInfo']['nickName']['str']
self.config["wxid"] = profile["Data"]["userInfo"]["userName"]["str"] self.config['wxid'] = profile['Data']['userInfo']['userName']['str']
thread_1.set() thread_1.set()
# asyncio.create_task(wechat_login_process) # asyncio.create_task(wechat_login_process)
threading.Thread(target=wechat_login_process).start() threading.Thread(target=wechat_login_process).start()
def connect_websocket_sync() -> None: def connect_websocket_sync() -> None:
thread_1.wait() thread_1.wait()
uri = f"{self.config['wechatpad_ws']}/GetSyncMsg?key={self.config['token']}" uri = f'{self.config["wechatpad_ws"]}/GetSyncMsg?key={self.config["token"]}'
self.ap.logger.info(f"Connecting to WebSocket: {uri}") self.ap.logger.info(f'Connecting to WebSocket: {uri}')
def on_message(ws, message): def on_message(ws, message):
try: try:
data = json.loads(message) data = json.loads(message)
self.ap.logger.debug(f"Received message: {data}") self.ap.logger.debug(f'Received message: {data}')
# 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法 # 这里需要确保ws_message是同步的或者使用asyncio.run调用异步方法
asyncio.run(self.ws_message(data)) asyncio.run(self.ws_message(data))
except json.JSONDecodeError: except json.JSONDecodeError:
self.ap.logger.error(f"Non-JSON message: {message[:100]}...") self.ap.logger.error(f'Non-JSON message: {message[:100]}...')
def on_error(ws, error): def on_error(ws, error):
self.ap.logger.error(f"WebSocket error: {str(error)[:200]}") self.ap.logger.error(f'WebSocket error: {str(error)[:200]}')
def on_close(ws, close_status_code, close_msg): def on_close(ws, close_status_code, close_msg):
self.ap.logger.info("WebSocket closed, reconnecting...") self.ap.logger.info('WebSocket closed, reconnecting...')
time.sleep(5) time.sleep(5)
connect_websocket_sync() # 自动重连 connect_websocket_sync() # 自动重连
def on_open(ws): def on_open(ws):
self.ap.logger.info("WebSocket connected successfully!") self.ap.logger.info('WebSocket connected successfully!')
ws = websocket.WebSocketApp( ws = websocket.WebSocketApp(
uri, uri, on_message=on_message, on_error=on_error, on_close=on_close, on_open=on_open
on_message=on_message,
on_error=on_error,
on_close=on_close,
on_open=on_open
)
ws.run_forever(
ping_interval=60,
ping_timeout=20
) )
ws.run_forever(ping_interval=60, ping_timeout=20)
# 直接调用同步版本(会阻塞) # 直接调用同步版本(会阻塞)
# connect_websocket_sync() # connect_websocket_sync()
# 这行代码会在WebSocket连接断开后才会执行 # 这行代码会在WebSocket连接断开后才会执行
# self.ap.logger.info("WebSocket client thread started") # self.ap.logger.info("WebSocket client thread started")
thread = threading.Thread( thread = threading.Thread(target=connect_websocket_sync, name='WebSocketClientThread', daemon=True)
target=connect_websocket_sync,
name="WebSocketClientThread",
daemon=True
)
thread.start() thread.start()
self.ap.logger.info("WebSocket client thread started") self.ap.logger.info('WebSocket client thread started')
async def kill(self) -> bool: async def kill(self) -> bool:
pass pass

View File

@@ -157,7 +157,7 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
token=config['token'], token=config['token'],
EncodingAESKey=config['EncodingAESKey'], EncodingAESKey=config['EncodingAESKey'],
contacts_secret=config['contacts_secret'], contacts_secret=config['contacts_secret'],
logger=self.logger logger=self.logger,
) )
async def reply_message( async def reply_message(
@@ -201,8 +201,8 @@ class WecomAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id self.bot_account_id = event.receiver_id
try: try:
return await callback(await self.event_converter.target2yiri(event), self) return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e: except Exception:
await self.logger.error(f"Error in wecom callback: {traceback.format_exc()}") await self.logger.error(f'Error in wecom callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage: if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message) self.bot.on_message('text')(on_message)

View File

@@ -145,7 +145,7 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
secret=config['secret'], secret=config['secret'],
token=config['token'], token=config['token'],
EncodingAESKey=config['EncodingAESKey'], EncodingAESKey=config['EncodingAESKey'],
logger=self.logger logger=self.logger,
) )
async def reply_message( async def reply_message(
@@ -178,8 +178,8 @@ class WecomCSAdapter(adapter.MessagePlatformAdapter):
self.bot_account_id = event.receiver_id self.bot_account_id = event.receiver_id
try: try:
return await callback(await self.event_converter.target2yiri(event), self) return await callback(await self.event_converter.target2yiri(event), self)
except Exception as e: except Exception:
await self.logger.error(f"Error in wecomcs callback: {traceback.format_exc()}") await self.logger.error(f'Error in wecomcs callback: {traceback.format_exc()}')
if event_type == platform_events.FriendMessage: if event_type == platform_events.FriendMessage:
self.bot.on_message('text')(on_message) self.bot.on_message('text')(on_message)

View File

@@ -812,12 +812,14 @@ class File(MessageComponent):
def __str__(self): def __str__(self):
return f'[文件]{self.name}' return f'[文件]{self.name}'
class Face(MessageComponent): class Face(MessageComponent):
"""系统表情 """系统表情
此处将超级表情骰子/划拳一同归类于face 此处将超级表情骰子/划拳一同归类于face
当face_type为rps(划拳)时 face_id 对应的是手势 当face_type为rps(划拳)时 face_id 对应的是手势
当face_type为dice(骰子)时 face_id 对应的是点数 当face_type为dice(骰子)时 face_id 对应的是点数
""" """
type: str = 'Face' type: str = 'Face'
"""表情类型""" """表情类型"""
face_type: str = 'face' face_type: str = 'face'
@@ -834,15 +836,15 @@ class Face(MessageComponent):
elif self.face_type == 'rps': elif self.face_type == 'rps':
return f'[表情]{self.face_name}({self.rps_data(self.face_id)})' return f'[表情]{self.face_name}({self.rps_data(self.face_id)})'
def rps_data(self, face_id): def rps_data(self, face_id):
rps_dict = { rps_dict = {
1 : "", 1: '',
2 : "剪刀", 2: '剪刀',
3 : "石头", 3: '石头',
} }
return rps_dict[face_id] return rps_dict[face_id]
# ================ 个人微信专用组件 ================ # ================ 个人微信专用组件 ================
@@ -971,5 +973,6 @@ class WeChatFile(MessageComponent):
"""文件地址""" """文件地址"""
file_base64: str = '' file_base64: str = ''
"""base64""" """base64"""
def __str__(self): def __str__(self):
return f'[文件]{self.file_name}' return f'[文件]{self.file_name}'

View File

@@ -127,6 +127,7 @@ class Message(pydantic.BaseModel):
class MessageChunk(pydantic.BaseModel): class MessageChunk(pydantic.BaseModel):
"""消息""" """消息"""
resp_message_id: typing.Optional[str] = None resp_message_id: typing.Optional[str] = None
"""消息id""" """消息id"""
@@ -210,6 +211,7 @@ class ToolCallChunk(pydantic.BaseModel):
function: FunctionCall function: FunctionCall
"""函数调用""" """函数调用"""
class Prompt(pydantic.BaseModel): class Prompt(pydantic.BaseModel):
"""供AI使用的Prompt""" """供AI使用的Prompt"""

View File

@@ -71,7 +71,7 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns: Returns:
llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 llm_entities.Message: 返回消息对象
""" """
pass pass
@@ -82,7 +82,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
model: RuntimeLLMModel, model: RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None, funcs: typing.List[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.MessageChunk: ) -> llm_entities.MessageChunk:
"""调用API """调用API
@@ -94,6 +93,6 @@ class LLMAPIRequester(metaclass=abc.ABCMeta):
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns: Returns:
llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象 typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象
""" """
pass pass

View File

@@ -8,7 +8,7 @@ import openai.types.chat.chat_completion as chat_completion
import httpx import httpx
from .. import errors, requester from .. import errors, requester
from ....core import entities as core_entities, app from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
from ...tools import entities as tools_entities from ...tools import entities as tools_entities
@@ -129,12 +129,10 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.MessageChunk: ) -> llm_entities.MessageChunk:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
args = {} args = {}
args['model'] = use_model.model_entity.name args['model'] = use_model.model_entity.name
@@ -158,7 +156,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
args['messages'] = messages args['messages'] = messages
if stream:
current_content = '' current_content = ''
args['stream'] = True args['stream'] = True
chunk_idx = 0 chunk_idx = 0
@@ -202,7 +199,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message: ) -> llm_entities.Message:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
@@ -289,7 +285,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
model: requester.RuntimeLLMModel, model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None, funcs: typing.List[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.MessageChunk: ) -> llm_entities.MessageChunk:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
@@ -309,7 +304,6 @@ class OpenAIChatCompletions(requester.LLMAPIRequester):
req_messages=req_messages, req_messages=req_messages,
use_model=model, use_model=model,
use_funcs=funcs, use_funcs=funcs,
stream=stream,
extra_args=extra_args, extra_args=extra_args,
): ):
yield item yield item

View File

@@ -12,7 +12,6 @@ import re
import openai.types.chat.chat_completion as chat_completion import openai.types.chat.chat_completion as chat_completion
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions): class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
"""Gitee AI ChatCompletions API 请求器""" """Gitee AI ChatCompletions API 请求器"""
@@ -56,7 +55,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
return message return message
async def _make_msg( async def _make_msg(
self, self,
chat_completion: chat_completion.ChatCompletion, chat_completion: chat_completion.ChatCompletion,
@@ -73,23 +71,25 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
# deepseek的reasoner模型 # deepseek的reasoner模型
if pipeline_config['trigger'].get('misc', '').get('remove_think'): if pipeline_config['trigger'].get('misc', '').get('remove_think'):
chatcmpl_message['content'] = re.sub(r'<think>.*?</think>', '', chatcmpl_message['content'], flags=re.DOTALL) chatcmpl_message['content'] = re.sub(
r'<think>.*?</think>', '', chatcmpl_message['content'], flags=re.DOTALL
)
else: else:
if reasoning_content is not None: if reasoning_content is not None:
chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content'] chatcmpl_message['content'] = (
'<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
)
message = llm_entities.Message(**chatcmpl_message) message = llm_entities.Message(**chatcmpl_message)
return message return message
async def _make_msg_chunk( async def _make_msg_chunk(
self, self,
pipeline_config: dict[str, typing.Any], pipeline_config: dict[str, typing.Any],
chat_completion: chat_completion.ChatCompletion, chat_completion: chat_completion.ChatCompletion,
idx: int, idx: int,
) -> llm_entities.MessageChunk: ) -> llm_entities.MessageChunk:
# 处理流式chunk和完整响应的差异 # 处理流式chunk和完整响应的差异
# print(chat_completion.choices[0]) # print(chat_completion.choices[0])
if hasattr(chat_completion, 'choices'): if hasattr(chat_completion, 'choices'):
@@ -104,7 +104,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
if 'role' not in delta or delta['role'] is None: if 'role' not in delta or delta['role'] is None:
delta['role'] = 'assistant' delta['role'] = 'assistant'
reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
delta['content'] = '' if delta['content'] is None else delta['content'] delta['content'] = '' if delta['content'] is None else delta['content']
@@ -115,7 +114,7 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
if delta['content'] == '<think>': if delta['content'] == '<think>':
self.is_think = True self.is_think = True
delta['content'] = '' delta['content'] = ''
if delta['content'] == rf'</think>': if delta['content'] == r'</think>':
self.is_think = False self.is_think = False
delta['content'] = '' delta['content'] = ''
if not self.is_think: if not self.is_think:
@@ -126,7 +125,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
if reasoning_content is not None: if reasoning_content is not None:
delta['content'] += reasoning_content delta['content'] += reasoning_content
message = llm_entities.MessageChunk(**delta) message = llm_entities.MessageChunk(**delta)
return message return message
@@ -137,7 +135,6 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
@@ -165,9 +162,8 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
args['messages'] = messages args['messages'] = messages
if stream:
current_content = '' current_content = ''
args["stream"] = True args['stream'] = True
chunk_idx = 0 chunk_idx = 0
self.is_content = False self.is_content = False
tool_calls_map: dict[str, llm_entities.ToolCall] = {} tool_calls_map: dict[str, llm_entities.ToolCall] = {}
@@ -186,15 +182,13 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
id=tool_call.id, id=tool_call.id,
type=tool_call.type, type=tool_call.type,
function=llm_entities.FunctionCall( function=llm_entities.FunctionCall(
name=tool_call.function.name if tool_call.function else '', name=tool_call.function.name if tool_call.function else '', arguments=''
arguments=''
), ),
) )
if tool_call.function and tool_call.function.arguments: if tool_call.function and tool_call.function.arguments:
# 流式处理中工具调用参数可能分多个chunk返回需要追加而不是覆盖 # 流式处理中工具调用参数可能分多个chunk返回需要追加而不是覆盖
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
chunk_idx += 1 chunk_idx += 1
chunk_choices = getattr(chunk, 'choices', None) chunk_choices = getattr(chunk, 'choices', None)
if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None):
@@ -202,7 +196,4 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
delta_message.content = current_content delta_message.content = current_content
if chunk_idx % 64 == 0 or delta_message.is_final: if chunk_idx % 64 == 0 or delta_message.is_final:
yield delta_message yield delta_message

View File

@@ -169,7 +169,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
args: dict, args: dict,
extra_body: dict = {}, extra_body: dict = {},
) -> chat_completion.ChatCompletion: ) -> chat_completion.ChatCompletion:
async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body):
yield chunk yield chunk
@@ -179,7 +178,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
chat_completion: chat_completion.ChatCompletion, chat_completion: chat_completion.ChatCompletion,
idx: int, idx: int,
) -> llm_entities.MessageChunk: ) -> llm_entities.MessageChunk:
# 处理流式chunk和完整响应的差异 # 处理流式chunk和完整响应的差异
# print(chat_completion.choices[0]) # print(chat_completion.choices[0])
if hasattr(chat_completion, 'choices'): if hasattr(chat_completion, 'choices'):
@@ -195,7 +193,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
if 'role' not in delta or delta['role'] is None: if 'role' not in delta or delta['role'] is None:
delta['role'] = 'assistant' delta['role'] = 'assistant'
reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
delta['content'] = '' if delta['content'] is None else delta['content'] delta['content'] = '' if delta['content'] is None else delta['content']
@@ -219,7 +216,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
else: else:
delta['content'] += reasoning_content delta['content'] += reasoning_content
message = llm_entities.MessageChunk(**delta) message = llm_entities.MessageChunk(**delta)
return message return message
@@ -230,7 +226,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
@@ -258,9 +253,8 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
args['messages'] = messages args['messages'] = messages
if stream:
current_content = '' current_content = ''
args["stream"] = True args['stream'] = True
chunk_idx = 0 chunk_idx = 0
self.is_content = False self.is_content = False
tool_calls_map: dict[str, llm_entities.ToolCall] = {} tool_calls_map: dict[str, llm_entities.ToolCall] = {}
@@ -279,15 +273,13 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
id=tool_call.id, id=tool_call.id,
type=tool_call.type, type=tool_call.type,
function=llm_entities.FunctionCall( function=llm_entities.FunctionCall(
name=tool_call.function.name if tool_call.function else '', name=tool_call.function.name if tool_call.function else '', arguments=''
arguments=''
), ),
) )
if tool_call.function and tool_call.function.arguments: if tool_call.function and tool_call.function.arguments:
# 流式处理中工具调用参数可能分多个chunk返回需要追加而不是覆盖 # 流式处理中工具调用参数可能分多个chunk返回需要追加而不是覆盖
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
chunk_idx += 1 chunk_idx += 1
chunk_choices = getattr(chunk, 'choices', None) chunk_choices = getattr(chunk, 'choices', None)
if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None):
@@ -295,12 +287,9 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
delta_message.content = current_content delta_message.content = current_content
if chunk_idx % 64 == 0 or delta_message.is_final: if chunk_idx % 64 == 0 or delta_message.is_final:
yield delta_message yield delta_message
# return # return
async def invoke_llm( async def invoke_llm(
self, self,
query: core_entities.Query, query: core_entities.Query,
@@ -340,14 +329,12 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
except openai.APIError as e: except openai.APIError as e:
raise errors.RequesterError(f'请求错误: {e.message}') raise errors.RequesterError(f'请求错误: {e.message}')
async def invoke_llm_stream( async def invoke_llm_stream(
self, self,
query: core_entities.Query, query: core_entities.Query,
model: requester.RuntimeLLMModel, model: requester.RuntimeLLMModel,
messages: typing.List[llm_entities.Message], messages: typing.List[llm_entities.Message],
funcs: typing.List[tools_entities.LLMFunction] = None, funcs: typing.List[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.MessageChunk: ) -> llm_entities.MessageChunk:
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
@@ -367,7 +354,6 @@ class ModelScopeChatCompletions(requester.LLMAPIRequester):
req_messages=req_messages, req_messages=req_messages,
use_model=model, use_model=model,
use_funcs=funcs, use_funcs=funcs,
stream=stream,
extra_args=extra_args, extra_args=extra_args,
): ):
yield item yield item

View File

@@ -5,8 +5,8 @@ import typing
from . import chatcmpl from . import chatcmpl
import openai.types.chat.chat_completion as chat_completion import openai.types.chat.chat_completion as chat_completion
from .. import errors, requester from .. import requester
from ....core import entities as core_entities, app from ....core import entities as core_entities
from ... import entities as llm_entities from ... import entities as llm_entities
from ...tools import entities as tools_entities from ...tools import entities as tools_entities
import re import re
@@ -40,16 +40,19 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
# deepseek的reasoner模型 # deepseek的reasoner模型
if pipeline_config['trigger'].get('misc', '').get('remove_think'): if pipeline_config['trigger'].get('misc', '').get('remove_think'):
chatcmpl_message['content'] = re.sub(r'<think>.*?</think>', '', chatcmpl_message['content'], flags=re.DOTALL) chatcmpl_message['content'] = re.sub(
r'<think>.*?</think>', '', chatcmpl_message['content'], flags=re.DOTALL
)
else: else:
if reasoning_content is not None: if reasoning_content is not None:
chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content'] chatcmpl_message['content'] = (
'<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
)
message = llm_entities.Message(**chatcmpl_message) message = llm_entities.Message(**chatcmpl_message)
return message return message
async def _make_msg_chunk( async def _make_msg_chunk(
self, self,
pipeline_config: dict[str, typing.Any], pipeline_config: dict[str, typing.Any],
@@ -80,7 +83,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
if '<think>' in delta['content']: if '<think>' in delta['content']:
self.is_think = True self.is_think = True
delta['content'] = '' delta['content'] = ''
if rf'</think>' in delta['content']: if r'</think>' in delta['content']:
self.is_think = False self.is_think = False
delta['content'] = '' delta['content'] = ''
if not self.is_think: if not self.is_think:
@@ -95,14 +98,12 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
return message return message
async def _closure_stream( async def _closure_stream(
self, self,
query: core_entities.Query, query: core_entities.Query,
req_messages: list[dict], req_messages: list[dict],
use_model: requester.RuntimeLLMModel, use_model: requester.RuntimeLLMModel,
use_funcs: list[tools_entities.LLMFunction] = None, use_funcs: list[tools_entities.LLMFunction] = None,
stream: bool = False,
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
self.client.api_key = use_model.token_mgr.get_token() self.client.api_key = use_model.token_mgr.get_token()
@@ -130,9 +131,8 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
args['messages'] = messages args['messages'] = messages
if stream:
current_content = '' current_content = ''
args["stream"] = True args['stream'] = True
chunk_idx = 0 chunk_idx = 0
self.is_content = False self.is_content = False
tool_calls_map: dict[str, llm_entities.ToolCall] = {} tool_calls_map: dict[str, llm_entities.ToolCall] = {}
@@ -151,8 +151,7 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
id=tool_call.id, id=tool_call.id,
type=tool_call.type, type=tool_call.type,
function=llm_entities.FunctionCall( function=llm_entities.FunctionCall(
name=tool_call.function.name if tool_call.function else '', name=tool_call.function.name if tool_call.function else '', arguments=''
arguments=''
), ),
) )
if tool_call.function and tool_call.function.arguments: if tool_call.function and tool_call.function.arguments:

View File

@@ -348,7 +348,9 @@ class DifyServiceAPIRunner(runner.RequestRunner):
except AttributeError: except AttributeError:
is_stream = False is_stream = False
batch_pending_index = 0 _ = is_stream
# batch_pending_index = 0
plain_text, image_ids = await self._preprocess_user_message(query) plain_text, image_ids = await self._preprocess_user_message(query)

View File

@@ -63,8 +63,7 @@ class LocalAgentRunner(runner.RequestRunner):
id=tool_call.id, id=tool_call.id,
type=tool_call.type, type=tool_call.type,
function=llm_entities.FunctionCall( function=llm_entities.FunctionCall(
name=tool_call.function.name if tool_call.function else '', name=tool_call.function.name if tool_call.function else '', arguments=''
arguments=''
), ),
) )
if tool_call.function and tool_call.function.arguments: if tool_call.function and tool_call.function.arguments:

View File

@@ -204,9 +204,9 @@ async def get_slack_image_to_base64(pic_url: str, bot_token: str):
try: try:
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get(pic_url, headers=headers) as resp: async with session.get(pic_url, headers=headers) as resp:
mime_type = resp.headers.get("Content-Type", "application/octet-stream") mime_type = resp.headers.get('Content-Type', 'application/octet-stream')
file_bytes = await resp.read() file_bytes = await resp.read()
base64_str = base64.b64encode(file_bytes).decode("utf-8") base64_str = base64.b64encode(file_bytes).decode('utf-8')
return f"data:{mime_type};base64,{base64_str}" return f'data:{mime_type};base64,{base64_str}'
except Exception as e: except Exception as e:
raise (e) raise (e)

View File

@@ -32,7 +32,7 @@ def import_dir(path: str):
rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '') rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '')
rel_path = rel_path[1:] rel_path = rel_path[1:]
rel_path = rel_path.replace('/', '.')[:-3] rel_path = rel_path.replace('/', '.')[:-3]
rel_path = rel_path.replace("\\",".") rel_path = rel_path.replace('\\', '.')
importlib.import_module(rel_path) importlib.import_module(rel_path)