mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
perf(wecom): add supports for images
This commit is contained in:
@@ -1,11 +1,14 @@
|
||||
from quart import request
|
||||
from .WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
|
||||
import base64
|
||||
import binascii
|
||||
import httpx
|
||||
from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
from .wecomevent import WecomEvent
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
class WecomClient():
|
||||
@@ -42,7 +45,6 @@ class WecomClient():
|
||||
else:
|
||||
raise Exception(f"未获取access token: {data}")
|
||||
|
||||
|
||||
async def get_users(self):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||
@@ -89,6 +91,30 @@ class WecomClient():
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send message: "+str(data))
|
||||
|
||||
async def send_image(self,user_id:str,agent_id:int,media_id:str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url+'/media/upload?access_token='+self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"touser" : user_id,
|
||||
"toparty" : "",
|
||||
"totag":"",
|
||||
"agentid" : agent_id,
|
||||
"msgtype" : "image",
|
||||
"image" : {
|
||||
"media_id" : media_id,
|
||||
},
|
||||
"safe":0,
|
||||
"enable_id_trans": 0,
|
||||
"enable_duplicate_check": 0,
|
||||
"duplicate_check_interval": 1800
|
||||
}
|
||||
response = await client.post(url,json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send image: "+str(data))
|
||||
|
||||
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
@@ -188,13 +214,92 @@ class WecomClient():
|
||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
||||
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
|
||||
}
|
||||
if message_data["MsgType"] == "image":
|
||||
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None
|
||||
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None
|
||||
|
||||
return message_data
|
||||
|
||||
@staticmethod
|
||||
async def get_image_type(image_bytes: bytes) -> str:
|
||||
"""
|
||||
通过图片的magic numbers判断图片类型
|
||||
"""
|
||||
magic_numbers = {
|
||||
b'\xFF\xD8\xFF': 'jpg',
|
||||
b'\x89\x50\x4E\x47': 'png',
|
||||
b'\x47\x49\x46': 'gif',
|
||||
b'\x42\x4D': 'bmp',
|
||||
b'\x00\x00\x01\x00': 'ico'
|
||||
}
|
||||
|
||||
for magic, ext in magic_numbers.items():
|
||||
if image_bytes.startswith(magic):
|
||||
return ext
|
||||
return 'jpg' # 默认返回jpg
|
||||
|
||||
|
||||
async def upload_to_work(self, image: platform_message.Image):
|
||||
"""
|
||||
获取 media_id
|
||||
"""
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
||||
file_bytes = None
|
||||
file_name = "uploaded_file.txt"
|
||||
|
||||
# 获取文件的二进制数据
|
||||
if image.path:
|
||||
async with aiofiles.open(image.path, 'rb') as f:
|
||||
file_bytes = await f.read()
|
||||
file_name = image.path.split('/')[-1]
|
||||
elif image.url:
|
||||
file_bytes = await self.download_image_to_bytes(image.url)
|
||||
file_name = image.url.split('/')[-1]
|
||||
elif image.base64:
|
||||
try:
|
||||
base64_data = image.base64
|
||||
if ',' in base64_data:
|
||||
base64_data = base64_data.split(',', 1)[1]
|
||||
padding = 4 - (len(base64_data) % 4) if len(base64_data) % 4 else 0
|
||||
padded_base64 = base64_data + '=' * padding
|
||||
file_bytes = base64.b64decode(padded_base64)
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f"Invalid base64 string: {str(e)}")
|
||||
else:
|
||||
raise ValueError("image对象出错")
|
||||
|
||||
# 设置 multipart/form-data 格式的文件
|
||||
boundary = "-------------------------acebdf13572468"
|
||||
headers = {
|
||||
'Content-Type': f'multipart/form-data; boundary={boundary}'
|
||||
}
|
||||
body = (
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
|
||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
||||
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
|
||||
|
||||
# 上传文件
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, content=body)
|
||||
data = response.json()
|
||||
if data.get('errcode', 0) != 0:
|
||||
raise Exception("failed to upload file")
|
||||
|
||||
return data.get('media_id')
|
||||
|
||||
|
||||
async def download_image_to_bytes(self,url:str) -> bytes:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
#进行media_id的获取
|
||||
async def get_media_id(self, image: platform_message.Image):
|
||||
|
||||
|
||||
|
||||
media_id = await self.upload_to_work(image=image)
|
||||
return media_id
|
||||
|
||||
@@ -36,6 +36,13 @@ class WecomEvent(dict):
|
||||
"""
|
||||
return self.get("MsgType", "")
|
||||
|
||||
@property
|
||||
def picurl(self) -> str:
|
||||
"""
|
||||
图片链接
|
||||
"""
|
||||
return self.get("PicUrl")
|
||||
|
||||
@property
|
||||
def detail_type(self) -> str:
|
||||
"""
|
||||
|
||||
@@ -12,7 +12,6 @@ from pkg.platform.adapter import MessageSourceAdapter
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from libs.wecom_api.wecomevent import WecomEvent
|
||||
from pkg.core import app
|
||||
|
||||
from .. import adapter
|
||||
from ...pipeline.longtext.strategies import forward
|
||||
from ...core import app
|
||||
@@ -20,24 +19,54 @@ from ..types import message as platform_message
|
||||
from ..types import events as platform_events
|
||||
from ..types import entities as platform_entities
|
||||
from ...command.errors import ParamNotEnoughError
|
||||
|
||||
from ...utils import image
|
||||
|
||||
class WecomMessageConverter(adapter.MessageConverter):
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(message_chain:platform_message.MessageChain):
|
||||
content=''
|
||||
async def yiri2target(
|
||||
message_chain: platform_message.MessageChain, bot: WecomClient
|
||||
):
|
||||
content_list = []
|
||||
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"content": "text",
|
||||
},
|
||||
{
|
||||
"type": "image",
|
||||
"media_id": "media_id",
|
||||
}
|
||||
]
|
||||
|
||||
for msg in message_chain:
|
||||
if type(msg) is platform_message.Plain:
|
||||
content+=msg.text
|
||||
|
||||
return content
|
||||
content_list.append({
|
||||
"type": "text",
|
||||
"content": msg.text,
|
||||
})
|
||||
elif type(msg) is platform_message.Image:
|
||||
content_list.append({
|
||||
"type": "image",
|
||||
"media_id": await bot.get_media_id(msg),
|
||||
})
|
||||
elif type(msg) is platform_message.Forward:
|
||||
for node in msg.node_list:
|
||||
content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot)))
|
||||
else:
|
||||
content_list.append({
|
||||
"type": "text",
|
||||
"content": str(msg),
|
||||
})
|
||||
|
||||
return content_list
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri(message:str,message_id:int = -1):
|
||||
async def target2yiri(message: str, message_id: int = -1):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id = message_id,time = datetime.datetime.now())
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
|
||||
yiri_msg_list.append(platform_message.Plain(text=message))
|
||||
@@ -45,29 +74,44 @@ class WecomMessageConverter(adapter.MessageConverter):
|
||||
|
||||
return chain
|
||||
|
||||
@staticmethod
|
||||
async def target2yiri_image(picurl: str, message_id: int = -1):
|
||||
yiri_msg_list = []
|
||||
yiri_msg_list.append(
|
||||
platform_message.Source(id=message_id, time=datetime.datetime.now())
|
||||
)
|
||||
image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl)
|
||||
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))
|
||||
chain = platform_message.MessageChain(yiri_msg_list)
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
class WecomEventConverter:
|
||||
|
||||
@staticmethod
|
||||
async def yiri2target(event:platform_events.Event,bot_account_id:int) -> WecomEvent:
|
||||
content = await WecomMessageConverter.yiri2target(event.message_chain)
|
||||
async def yiri2target(
|
||||
event: platform_events.Event, bot_account_id: int, bot: WecomClient
|
||||
) -> WecomEvent:
|
||||
# only for extracting user information
|
||||
|
||||
if type(event) is platform_events.GroupMessage:
|
||||
pass
|
||||
if type(event) is platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
if type(event) is platform_events.FriendMessage:
|
||||
payload = {
|
||||
"MsgType": "text",
|
||||
"Content": content,
|
||||
"FromUserName": event.sender.id,
|
||||
"ToUserName": bot_account_id,
|
||||
"CreateTime": int(datetime.datetime.now().timestamp()),
|
||||
"AgentID": event.sender.nickname
|
||||
}
|
||||
if type(event) is platform_events.FriendMessage:
|
||||
|
||||
payload = {
|
||||
"MsgType": "text",
|
||||
"Content": '',
|
||||
"FromUserName": event.sender.id,
|
||||
"ToUserName": bot_account_id,
|
||||
"CreateTime": int(datetime.datetime.now().timestamp()),
|
||||
"AgentID": event.sender.nickname,
|
||||
}
|
||||
wecom_event = WecomEvent.from_payload(payload=payload)
|
||||
if not wecom_event:
|
||||
raise ValueError("无法从 message_data 构造 WecomEvent 对象")
|
||||
|
||||
return wecom_event
|
||||
|
||||
@staticmethod
|
||||
@@ -82,86 +126,113 @@ class WecomEventConverter:
|
||||
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
|
||||
"""
|
||||
# 转换消息链
|
||||
yiri_chain = await WecomMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
|
||||
# 判断消息类型并进行转换
|
||||
# if event.message_type == "private": 默认消息都是从好友发出
|
||||
|
||||
friend = platform_entities.Friend(
|
||||
id=event.user_id,
|
||||
nickname=str(event.agent_id),
|
||||
remark="",
|
||||
)
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend,
|
||||
message_chain=yiri_chain,
|
||||
time=event.timestamp
|
||||
if event.type == "text":
|
||||
yiri_chain = await WecomMessageConverter.target2yiri(
|
||||
event.message, event.message_id
|
||||
)
|
||||
|
||||
friend = platform_entities.Friend(
|
||||
id=event.user_id,
|
||||
nickname=str(event.agent_id),
|
||||
remark="",
|
||||
)
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend, message_chain=yiri_chain, time=event.timestamp
|
||||
)
|
||||
elif event.type == "image":
|
||||
friend = platform_entities.Friend(
|
||||
id=event.user_id,
|
||||
nickname=str(event.agent_id),
|
||||
remark="",
|
||||
)
|
||||
|
||||
yiri_chain = await WecomMessageConverter.target2yiri_image(
|
||||
picurl=event.picurl, message_id=event.message_id
|
||||
)
|
||||
|
||||
return platform_events.FriendMessage(
|
||||
sender=friend, message_chain=yiri_chain, time=event.timestamp
|
||||
)
|
||||
|
||||
|
||||
@adapter.adapter_class("wecom")
|
||||
class WecomeAdapter(adapter.MessageSourceAdapter):
|
||||
|
||||
bot:WecomClient
|
||||
ap:app.Application
|
||||
bot_account_id:str
|
||||
message_converter:WecomMessageConverter = WecomMessageConverter()
|
||||
event_converter:WecomEventConverter = WecomEventConverter()
|
||||
config:dict
|
||||
ap:app.Application
|
||||
bot: WecomClient
|
||||
ap: app.Application
|
||||
bot_account_id: str
|
||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||
event_converter: WecomEventConverter = WecomEventConverter()
|
||||
config: dict
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, config: dict, ap:app.Application):
|
||||
def __init__(self, config: dict, ap: app.Application):
|
||||
self.config = config
|
||||
#这里需要对config里的内容换成企业微信的config。是config:corpid,token......
|
||||
|
||||
self.ap = ap
|
||||
|
||||
required_keys = ["corpid","secret","token","EncodingAESKey","contacts_secret"]
|
||||
required_keys = [
|
||||
"corpid",
|
||||
"secret",
|
||||
"token",
|
||||
"EncodingAESKey",
|
||||
"contacts_secret",
|
||||
]
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员")
|
||||
|
||||
self.bot = WecomClient(
|
||||
corpid=config['corpid'],
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
contacts_secret=config['contacts_secret']
|
||||
corpid=config["corpid"],
|
||||
secret=config["secret"],
|
||||
token=config["token"],
|
||||
EncodingAESKey=config["EncodingAESKey"],
|
||||
contacts_secret=config["contacts_secret"],
|
||||
)
|
||||
|
||||
async def reply_message(self,message_source:platform_events.MessageEvent,message:platform_message.MessageChain,
|
||||
quote_origin:bool=False,
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
Wecom_event = await WecomEventConverter.yiri2target(message_source,self.bot_account_id)
|
||||
Wecom_msg = await WecomMessageConverter.yiri2target(message)
|
||||
# message_converter传回一个消息str
|
||||
|
||||
user_id = Wecom_event.user_id
|
||||
agent_id = Wecom_event.agent_id
|
||||
return await self.bot.send_private_msg(user_id=user_id,agent_id=agent_id,content=Wecom_msg)
|
||||
Wecom_event = await WecomEventConverter.yiri2target(
|
||||
message_source, self.bot_account_id, self.bot
|
||||
)
|
||||
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
for content in content_list:
|
||||
if content["type"] == "text":
|
||||
await self.bot.send_private_msg(Wecom_event.user_id, Wecom_event.agent_id, content["content"])
|
||||
elif content["type"] == "image":
|
||||
await self.bot.send_image(Wecom_event.user_id, Wecom_event.agent_id, content["media_id"])
|
||||
|
||||
async def send_message(
|
||||
self, target_type: str, target_id: str, message: platform_message.MessageChain
|
||||
):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type:typing.Type[platform_events.Event],
|
||||
callback:typing.Callable[[platform_events.Event,adapter.MessageSourceAdapter],None],
|
||||
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[
|
||||
[platform_events.Event, adapter.MessageSourceAdapter], None
|
||||
],
|
||||
):
|
||||
async def on_message(event:WecomEvent):
|
||||
async def on_message(event: WecomEvent):
|
||||
self.bot_account_id = event.receiver_id
|
||||
try:
|
||||
return await callback(await self.event_converter.target2yiri(event),self)
|
||||
return await callback(
|
||||
await self.event_converter.target2yiri(event), self
|
||||
)
|
||||
except:
|
||||
traceback.print_exc()
|
||||
|
||||
if event_type == platform_events.FriendMessage:
|
||||
self.bot.on_message("text")(on_message)
|
||||
self.bot.on_message("image")(on_message)
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
@@ -170,24 +241,18 @@ class WecomeAdapter(adapter.MessageSourceAdapter):
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
|
||||
await self.bot.run_task(host=self.config['host'],port=self.config['port'],shutdown_trigger=shutdown_trigger_placeholder)
|
||||
await self.bot.run_task(
|
||||
host=self.config["host"],
|
||||
port=self.config["port"],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
async def unregister_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None]):
|
||||
async def unregister_listener(
|
||||
self,
|
||||
event_type: type,
|
||||
callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None],
|
||||
):
|
||||
return super().unregister_listener(event_type, callback)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,31 @@ import ssl
|
||||
import aiohttp
|
||||
import PIL.Image
|
||||
|
||||
async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]:
|
||||
"""
|
||||
下载企业微信图片并转换为 base64
|
||||
:param pic_url: 企业微信图片URL
|
||||
:return: (base64_str, image_format)
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(pic_url) as response:
|
||||
if response.status != 200:
|
||||
raise Exception(f"Failed to download image: {response.status}")
|
||||
|
||||
# 读取图片数据
|
||||
image_data = await response.read()
|
||||
|
||||
# 获取图片格式
|
||||
content_type = response.headers.get('Content-Type', '')
|
||||
image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg'
|
||||
|
||||
# 转换为 base64
|
||||
import base64
|
||||
image_base64 = base64.b64encode(image_data).decode('utf-8')
|
||||
|
||||
return image_base64, image_format
|
||||
|
||||
|
||||
|
||||
def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
|
||||
"""获取QQ图片的下载链接"""
|
||||
|
||||
Reference in New Issue
Block a user