perf(wecom): add supports for images

This commit is contained in:
wangcham
2025-01-20 21:24:46 -05:00
parent 60d4f3d77c
commit 7a01cff0c8
4 changed files with 312 additions and 110 deletions

View File

@@ -1,11 +1,14 @@
from quart import request
from .WXBizMsgCrypt3 import WXBizMsgCrypt
import base64
import binascii
import httpx
from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from .wecomevent import WecomEvent
from pkg.platform.types import events as platform_events, message as platform_message
import aiofiles
class WecomClient():
@@ -42,7 +45,6 @@ class WecomClient():
else:
raise Exception(f"未获取access token: {data}")
async def get_users(self):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
@@ -88,6 +90,30 @@ class WecomClient():
data = response.json()
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
async def send_image(self,user_id:str,agent_id:int,media_id:str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/media/upload?access_token='+self.access_token
async with httpx.AsyncClient() as client:
params = {
"touser" : user_id,
"toparty" : "",
"totag":"",
"agentid" : agent_id,
"msgtype" : "image",
"image" : {
"media_id" : media_id,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
}
response = await client.post(url,json=params)
data = response.json()
if data['errcode'] != 0:
raise Exception("Failed to send image: "+str(data))
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
if not await self.check_access_token():
@@ -188,13 +214,92 @@ class WecomClient():
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
}
return message_data
if message_data["MsgType"] == "image":
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None
return message_data
@staticmethod
async def get_image_type(image_bytes: bytes) -> str:
"""
通过图片的magic numbers判断图片类型
"""
magic_numbers = {
b'\xFF\xD8\xFF': 'jpg',
b'\x89\x50\x4E\x47': 'png',
b'\x47\x49\x46': 'gif',
b'\x42\x4D': 'bmp',
b'\x00\x00\x01\x00': 'ico'
}
for magic, ext in magic_numbers.items():
if image_bytes.startswith(magic):
return ext
return 'jpg' # 默认返回jpg
async def upload_to_work(self, image: platform_message.Image):
"""
获取 media_id
"""
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
file_bytes = None
file_name = "uploaded_file.txt"
# 获取文件的二进制数据
if image.path:
async with aiofiles.open(image.path, 'rb') as f:
file_bytes = await f.read()
file_name = image.path.split('/')[-1]
elif image.url:
file_bytes = await self.download_image_to_bytes(image.url)
file_name = image.url.split('/')[-1]
elif image.base64:
try:
base64_data = image.base64
if ',' in base64_data:
base64_data = base64_data.split(',', 1)[1]
padding = 4 - (len(base64_data) % 4) if len(base64_data) % 4 else 0
padded_base64 = base64_data + '=' * padding
file_bytes = base64.b64decode(padded_base64)
except binascii.Error as e:
raise ValueError(f"Invalid base64 string: {str(e)}")
else:
raise ValueError("image对象出错")
# 设置 multipart/form-data 格式的文件
boundary = "-------------------------acebdf13572468"
headers = {
'Content-Type': f'multipart/form-data; boundary={boundary}'
}
body = (
f"--{boundary}\r\n"
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
f"Content-Type: application/octet-stream\r\n\r\n"
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
# 上传文件
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, content=body)
data = response.json()
if data.get('errcode', 0) != 0:
raise Exception("failed to upload file")
return data.get('media_id')
async def download_image_to_bytes(self,url:str) -> bytes:
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
return response.content
#进行media_id的获取
async def get_media_id(self, image: platform_message.Image):
media_id = await self.upload_to_work(image=image)
return media_id

View File

@@ -35,6 +35,13 @@ class WecomEvent(dict):
str: 事件类型。
"""
return self.get("MsgType", "")
@property
def picurl(self) -> str:
"""
图片链接
"""
return self.get("PicUrl")
@property
def detail_type(self) -> str:

View File

@@ -12,7 +12,6 @@ from pkg.platform.adapter import MessageSourceAdapter
from pkg.platform.types import events as platform_events, message as platform_message
from libs.wecom_api.wecomevent import WecomEvent
from pkg.core import app
from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
@@ -20,24 +19,54 @@ from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
from ...utils import image
class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain:platform_message.MessageChain):
content=''
for msg in message_chain:
if type(msg) is platform_message.Plain:
content+=msg.text
return content
@staticmethod
async def target2yiri(message:str,message_id:int = -1):
async def yiri2target(
message_chain: platform_message.MessageChain, bot: WecomClient
):
content_list = []
[
{
"type": "text",
"content": "text",
},
{
"type": "image",
"media_id": "media_id",
}
]
for msg in message_chain:
if type(msg) is platform_message.Plain:
content_list.append({
"type": "text",
"content": msg.text,
})
elif type(msg) is platform_message.Image:
content_list.append({
"type": "image",
"media_id": await bot.get_media_id(msg),
})
elif type(msg) is platform_message.Forward:
for node in msg.node_list:
content_list.extend((await WecomMessageConverter.yiri2target(node.message_chain, bot)))
else:
content_list.append({
"type": "text",
"content": str(msg),
})
return content_list
@staticmethod
async def target2yiri(message: str, message_id: int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id = message_id,time = datetime.datetime.now())
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Plain(text=message))
@@ -45,31 +74,46 @@ class WecomMessageConverter(adapter.MessageConverter):
return chain
@staticmethod
async def target2yiri_image(picurl: str, message_id: int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id=message_id, time=datetime.datetime.now())
)
image_base64, image_format = await image.get_wecom_image_base64(pic_url=picurl)
yiri_msg_list.append(platform_message.Image(base64=f"data:image/{image_format};base64,{image_base64}"))
chain = platform_message.MessageChain(yiri_msg_list)
return chain
class WecomEventConverter:
@staticmethod
async def yiri2target(event:platform_events.Event,bot_account_id:int) -> WecomEvent:
content = await WecomMessageConverter.yiri2target(event.message_chain)
if type(event) is platform_events.GroupMessage:
pass
if type(event) is platform_events.FriendMessage:
payload = {
"MsgType": "text",
"Content": content,
"FromUserName": event.sender.id,
"ToUserName": bot_account_id,
"CreateTime": int(datetime.datetime.now().timestamp()),
"AgentID": event.sender.nickname
}
@staticmethod
async def yiri2target(
event: platform_events.Event, bot_account_id: int, bot: WecomClient
) -> WecomEvent:
# only for extracting user information
if type(event) is platform_events.GroupMessage:
pass
if type(event) is platform_events.FriendMessage:
payload = {
"MsgType": "text",
"Content": '',
"FromUserName": event.sender.id,
"ToUserName": bot_account_id,
"CreateTime": int(datetime.datetime.now().timestamp()),
"AgentID": event.sender.nickname,
}
wecom_event = WecomEvent.from_payload(payload=payload)
if not wecom_event:
raise ValueError("无法从 message_data 构造 WecomEvent 对象")
return wecom_event
@staticmethod
async def target2yiri(event: WecomEvent):
"""
@@ -82,112 +126,133 @@ class WecomEventConverter:
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
"""
# 转换消息链
yiri_chain = await WecomMessageConverter.target2yiri(
event.message, event.message_id
)
# 判断消息类型并进行转换
# if event.message_type == "private": 默认消息都是从好友发出
friend = platform_entities.Friend(
id=event.user_id,
nickname=str(event.agent_id),
remark="",
)
return platform_events.FriendMessage(
sender=friend,
message_chain=yiri_chain,
time=event.timestamp
if event.type == "text":
yiri_chain = await WecomMessageConverter.target2yiri(
event.message, event.message_id
)
friend = platform_entities.Friend(
id=event.user_id,
nickname=str(event.agent_id),
remark="",
)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, time=event.timestamp
)
elif event.type == "image":
friend = platform_entities.Friend(
id=event.user_id,
nickname=str(event.agent_id),
remark="",
)
yiri_chain = await WecomMessageConverter.target2yiri_image(
picurl=event.picurl, message_id=event.message_id
)
return platform_events.FriendMessage(
sender=friend, message_chain=yiri_chain, time=event.timestamp
)
@adapter.adapter_class("wecom")
class WecomeAdapter(adapter.MessageSourceAdapter):
bot:WecomClient
ap:app.Application
bot_account_id:str
message_converter:WecomMessageConverter = WecomMessageConverter()
event_converter:WecomEventConverter = WecomEventConverter()
config:dict
ap:app.Application
bot: WecomClient
ap: app.Application
bot_account_id: str
message_converter: WecomMessageConverter = WecomMessageConverter()
event_converter: WecomEventConverter = WecomEventConverter()
config: dict
ap: app.Application
def __init__(self, config: dict, ap:app.Application):
def __init__(self, config: dict, ap: app.Application):
self.config = config
#这里需要对config里的内容换成企业微信的config。是config:corpid,token......
self.ap = ap
required_keys = ["corpid","secret","token","EncodingAESKey","contacts_secret"]
required_keys = [
"corpid",
"secret",
"token",
"EncodingAESKey",
"contacts_secret",
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员")
self.bot = WecomClient(
corpid=config['corpid'],
secret=config['secret'],
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
contacts_secret=config['contacts_secret']
corpid=config["corpid"],
secret=config["secret"],
token=config["token"],
EncodingAESKey=config["EncodingAESKey"],
contacts_secret=config["contacts_secret"],
)
async def reply_message(self,message_source:platform_events.MessageEvent,message:platform_message.MessageChain,
quote_origin:bool=False,
async def reply_message(
self,
message_source: platform_events.MessageEvent,
message: platform_message.MessageChain,
quote_origin: bool = False,
):
Wecom_event = await WecomEventConverter.yiri2target(message_source,self.bot_account_id)
Wecom_msg = await WecomMessageConverter.yiri2target(message)
# message_converter传回一个消息str
user_id = Wecom_event.user_id
agent_id = Wecom_event.agent_id
return await self.bot.send_private_msg(user_id=user_id,agent_id=agent_id,content=Wecom_msg)
Wecom_event = await WecomEventConverter.yiri2target(
message_source, self.bot_account_id, self.bot
)
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
for content in content_list:
if content["type"] == "text":
await self.bot.send_private_msg(Wecom_event.user_id, Wecom_event.agent_id, content["content"])
elif content["type"] == "image":
await self.bot.send_image(Wecom_event.user_id, Wecom_event.agent_id, content["media_id"])
async def send_message(
self, target_type: str, target_id: str, message: platform_message.MessageChain
):
pass
def register_listener(
self,
event_type:typing.Type[platform_events.Event],
callback:typing.Callable[[platform_events.Event,adapter.MessageSourceAdapter],None],
self,
event_type: typing.Type[platform_events.Event],
callback: typing.Callable[
[platform_events.Event, adapter.MessageSourceAdapter], None
],
):
async def on_message(event:WecomEvent):
async def on_message(event: WecomEvent):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event),self)
return await callback(
await self.event_converter.target2yiri(event), self
)
except:
traceback.print_exc()
if event_type == platform_events.FriendMessage:
self.bot.on_message("text")(on_message)
self.bot.on_message("image")(on_message)
elif event_type == platform_events.GroupMessage:
pass
async def run_async(self):
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await self.bot.run_task(host=self.config['host'],port=self.config['port'],shutdown_trigger=shutdown_trigger_placeholder)
await self.bot.run_task(
host=self.config["host"],
port=self.config["port"],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False
async def unregister_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None]):
async def unregister_listener(
self,
event_type: type,
callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None],
):
return super().unregister_listener(event_type, callback)

View File

@@ -7,6 +7,31 @@ import ssl
import aiohttp
import PIL.Image
async def get_wecom_image_base64(pic_url: str) -> tuple[str, str]:
"""
下载企业微信图片并转换为 base64
:param pic_url: 企业微信图片URL
:return: (base64_str, image_format)
"""
async with aiohttp.ClientSession() as session:
async with session.get(pic_url) as response:
if response.status != 200:
raise Exception(f"Failed to download image: {response.status}")
# 读取图片数据
image_data = await response.read()
# 获取图片格式
content_type = response.headers.get('Content-Type', '')
image_format = content_type.split('/')[-1] # 例如 'image/jpeg' -> 'jpeg'
# 转换为 base64
import base64
image_base64 = base64.b64encode(image_data).decode('utf-8')
return image_base64, image_format
def get_qq_image_downloadable_url(image_url: str) -> tuple[str, dict]:
"""获取QQ图片的下载链接"""