feat: add supports for wecom

This commit is contained in:
wangcham
2025-01-12 05:09:53 -05:00
parent fd30022065
commit 60d4f3d77c
19 changed files with 939 additions and 58 deletions

5
.gitignore vendored
View File

@@ -36,4 +36,7 @@ res/instance_id.json
.DS_Store
/data
botpy.log*
/poc
/poc
/libs/wecom_api/test.py
/venv

View File

@@ -0,0 +1,279 @@
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
""" 对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2014 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
import random
import hashlib
import time
import struct
from Crypto.Cipher import AES
import xml.etree.cElementTree as ET
import socket
from . import ierror
"""
Crypto.Cipher包已不再维护开发者可以通过以下命令下载安装最新版的加解密工具包
pip install pycryptodome
"""
class FormatException(Exception):
pass
def throw_exception(message, exception_class=FormatException):
"""my define raise exception function"""
raise exception_class(message)
class SHA1:
"""计算企业微信的消息签名接口"""
def getSHA1(self, token, timestamp, nonce, encrypt):
"""用SHA1算法生成安全签名
@param token: 票据
@param timestamp: 时间戳
@param encrypt: 密文
@param nonce: 随机字符串
@return: 安全签名
"""
try:
sortlist = [token, timestamp, nonce, encrypt]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode())
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_ComputeSignature_Error, None
class XMLParse:
"""提供提取消息格式中的密文及生成回复消息格式的接口"""
# xml消息模板
AES_TEXT_RESPONSE_TEMPLATE = """<xml>
<Encrypt><![CDATA[%(msg_encrypt)s]]></Encrypt>
<MsgSignature><![CDATA[%(msg_signaturet)s]]></MsgSignature>
<TimeStamp>%(timestamp)s</TimeStamp>
<Nonce><![CDATA[%(nonce)s]]></Nonce>
</xml>"""
def extract(self, xmltext):
"""提取出xml数据包中的加密消息
@param xmltext: 待提取的xml字符串
@return: 提取出的加密消息字符串
"""
try:
xml_tree = ET.fromstring(xmltext)
encrypt = xml_tree.find("Encrypt")
return ierror.WXBizMsgCrypt_OK, encrypt.text
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_ParseXml_Error, None
def generate(self, encrypt, signature, timestamp, nonce):
"""生成xml消息
@param encrypt: 加密后的消息密文
@param signature: 安全签名
@param timestamp: 时间戳
@param nonce: 随机字符串
@return: 生成的xml字符串
"""
resp_dict = {
'msg_encrypt': encrypt,
'msg_signaturet': signature,
'timestamp': timestamp,
'nonce': nonce,
}
resp_xml = self.AES_TEXT_RESPONSE_TEMPLATE % resp_dict
return resp_xml
class PKCS7Encoder():
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
""" 对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文
@return: 补齐明文字符串
"""
text_length = len(text)
# 计算需要填充的位数
amount_to_pad = self.block_size - (text_length % self.block_size)
if amount_to_pad == 0:
amount_to_pad = self.block_size
# 获得补位所用的字符
pad = chr(amount_to_pad)
return text + (pad * amount_to_pad).encode()
def decode(self, decrypted):
"""删除解密后明文的补位字符
@param decrypted: 解密后的明文
@return: 删除补位字符后的明文
"""
pad = ord(decrypted[-1])
if pad < 1 or pad > 32:
pad = 0
return decrypted[:-pad]
class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
self.mode = AES.MODE_CBC
def encrypt(self, text, receiveid):
"""对明文进行加密
@param text: 需要加密的明文
@return: 加密得到的字符串
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
text = pkcs7.encode(text)
# 加密
cryptor = AES.new(self.key, self.mode, self.key[:16])
try:
ciphertext = cryptor.encrypt(text)
# 使用BASE64对加密后的字符串进行编码
return ierror.WXBizMsgCrypt_OK, base64.b64encode(ciphertext)
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_EncryptAES_Error, None
def decrypt(self, text, receiveid):
"""对解密后的明文进行补位删除
@param text: 密文
@return: 删除填充补位后的明文
"""
try:
cryptor = AES.new(self.key, self.mode, self.key[:16])
# 使用BASE64对密文进行解码然后AES-CBC解密
plain_text = cryptor.decrypt(base64.b64decode(text))
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_DecryptAES_Error, None
try:
pad = plain_text[-1]
# 去掉补位字符串
# pkcs7 = PKCS7Encoder()
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
xml_content = content[4: xml_len + 4]
from_receiveid = content[xml_len + 4:]
except Exception as e:
logger = logging.getLogger()
logger.error(e)
return ierror.WXBizMsgCrypt_IllegalBuffer, None
if from_receiveid.decode('utf8') != receiveid:
return ierror.WXBizMsgCrypt_ValidateCorpid_Error, None
return 0, xml_content
def get_random_str(self):
""" 随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
class WXBizMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
assert len(self.key) == 32
except:
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId
# 验证URL
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sEchoStr: 随机串对应URL参数的echostr
# @param sReplyEchoStr: 解密之后的echostr当return返回0时有效
# @return成功0失败返回对应的错误码
def VerifyURL(self, sMsgSignature, sTimeStamp, sNonce, sEchoStr):
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, sEchoStr)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, sReplyEchoStr = pc.decrypt(sEchoStr, self.m_sReceiveId)
return ret, sReplyEchoStr
def EncryptMsg(self, sReplyMsg, sNonce, timestamp=None):
# 将企业回复用户的消息加密打包
# @param sReplyMsg: 企业号待回复用户的消息xml格式的字符串
# @param sTimeStamp: 时间戳可以自己生成也可以用URL参数的timestamp,如为None则自动用当前时间
# @param sNonce: 随机串可以自己生成也可以用URL参数的nonce
# sEncryptMsg: 加密后的可以直接回复用户的密文包括msg_signature, timestamp, nonce, encrypt的xml格式的字符串,
# return成功0sEncryptMsg,失败返回对应的错误码None
pc = Prpcrypt(self.key)
ret, encrypt = pc.encrypt(sReplyMsg, self.m_sReceiveId)
encrypt = encrypt.decode('utf8')
if ret != 0:
return ret, None
if timestamp is None:
timestamp = str(int(time.time()))
# 生成安全签名
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, timestamp, sNonce, encrypt)
if ret != 0:
return ret, None
xmlParse = XMLParse()
return ret, xmlParse.generate(encrypt, signature, timestamp, sNonce)
def DecryptMsg(self, sPostData, sMsgSignature, sTimeStamp, sNonce):
# 检验消息的真实性,并且获取解密后的明文
# @param sMsgSignature: 签名串对应URL参数的msg_signature
# @param sTimeStamp: 时间戳对应URL参数的timestamp
# @param sNonce: 随机串对应URL参数的nonce
# @param sPostData: 密文对应POST请求的数据
# xml_content: 解密后的原文当return返回0时有效
# @return: 成功0失败返回对应的错误码
# 验证安全签名
xmlParse = XMLParse()
ret, encrypt = xmlParse.extract(sPostData)
if ret != 0:
return ret, None
sha1 = SHA1()
ret, signature = sha1.getSHA1(self.m_sToken, sTimeStamp, sNonce, encrypt)
if ret != 0:
return ret, None
if not signature == sMsgSignature:
return ierror.WXBizMsgCrypt_ValidateSignature_Error, None
pc = Prpcrypt(self.key)
ret, xml_content = pc.decrypt(encrypt, self.m_sReceiveId)
return ret, xml_content

View File

200
libs/wecom_api/api.py Normal file
View File

@@ -0,0 +1,200 @@
from quart import request
from .WXBizMsgCrypt3 import WXBizMsgCrypt
import httpx
from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from .wecomevent import WecomEvent
class WecomClient():
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str):
self.corpid = corpid
self.secret = secret
self.access_token_for_contacts =''
self.token = token
self.aes = EncodingAESKey
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
self.access_token = ''
self.secret_for_contacts = contacts_secret
self.app = Quart(__name__)
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self._message_handlers = {
"example":[],
}
#access——token操作
async def check_access_token(self):
return bool(self.access_token and self.access_token.strip())
async def check_access_token_for_contacts(self):
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
async def get_access_token(self,secret):
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
async with httpx.AsyncClient() as client:
response = await client.get(url)
data = response.json()
if 'access_token' in data:
return data['access_token']
else:
raise Exception(f"未获取access token: {data}")
async def get_users(self):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts
async with httpx.AsyncClient() as client:
params = {
"cursor":"",
"limit":10000,
}
response = await client.post(url,json=params)
data = response.json()
if data['errcode'] == 0:
dept_users = data['dept_user']
userid = []
for user in dept_users:
userid.append(user["userid"])
return userid
else:
raise Exception("未获取用户")
async def send_to_all(self,content:str):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts
user_ids = await self.get_users()
user_ids_string = "|".join(user_ids)
async with httpx.AsyncClient() as client:
params = {
"touser" : user_ids_string,
"msgtype" : "text",
"agentid" : 1000002,
"text" : {
"content" : content,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
}
response = await client.post(url,json=params)
data = response.json()
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/message/send?access_token='+self.access_token
async with httpx.AsyncClient() as client:
params={
"touser" : user_id,
"msgtype" : "text",
"agentid" : agent_id,
"text" : {
"content" : content,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
}
response = await client.post(url,json=params)
data = response.json()
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
async def handle_callback_request(self):
"""
处理回调请求,包括 GET 验证和 POST 消息接收。
"""
try:
msg_signature = request.args.get("msg_signature")
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
if request.method == "GET":
echostr = request.args.get("echostr")
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
raise Exception(f"验证失败,错误码: {ret}")
return reply_echo_str
elif request.method == "POST":
encrypt_msg = await request.data
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0:
raise Exception(f"消息解密失败,错误码: {ret}")
# 解析消息并处理
message_data = await self.get_message(xml_msg)
if message_data:
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
if event:
await self._handle_message(event)
return "success"
except Exception as e:
return f"Error processing request: {str(e)}", 400
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str):
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[WecomEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: WecomEvent):
"""
处理消息事件。
"""
msg_type = event.type
if msg_type in self._message_handlers:
for handler in self._message_handlers[msg_type]:
await handler(event)
async def get_message(self, xml_msg: str) -> Dict[str, Any]:
"""
解析微信返回的 XML 消息并转换为字典。
"""
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
}
return message_data

20
libs/wecom_api/ierror.py Normal file
View File

@@ -0,0 +1,20 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-
#########################################################################
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
WXBizMsgCrypt_ParseXml_Error = -40002
WXBizMsgCrypt_ComputeSignature_Error = -40003
WXBizMsgCrypt_IllegalAesKey = -40004
WXBizMsgCrypt_ValidateCorpid_Error = -40005
WXBizMsgCrypt_EncryptAES_Error = -40006
WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnXml_Error = -40011

View File

@@ -0,0 +1,172 @@
from typing import Dict, Any, Optional
class WecomEvent(dict):
"""
封装从企业微信收到的事件数据对象(字典),提供属性以获取其中的字段。
除 `type` 和 `detail_type` 属性对于任何事件都有效外,其它属性是否存在(若不存在则返回 `None`)依事件类型不同而不同。
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]:
"""
从企业微信事件数据构造 `WecomEvent` 对象。
Args:
payload (Dict[str, Any]): 解密后的企业微信事件数据。
Returns:
Optional[WecomEvent]: 如果事件数据合法,则返回 WecomEvent 对象;否则返回 None。
"""
try:
event = WecomEvent(payload)
_ = event.type, event.detail_type # 确保必须字段存在
return event
except KeyError:
return None
@property
def type(self) -> str:
"""
事件类型,例如 "message""event""text" 等。
Returns:
str: 事件类型。
"""
return self.get("MsgType", "")
@property
def detail_type(self) -> str:
"""
事件详细类型,依 `type` 的不同而不同。例如:
- 消息事件: "text", "image", "voice", 等
- 事件通知: "subscribe", "unsubscribe", "click", 等
Returns:
str: 事件详细类型。
"""
if self.type == "event":
return self.get("Event", "")
return self.type
@property
def name(self) -> str:
"""
事件名,对于消息事件是 `type.detail_type`,对于其他事件是 `event_type`。
Returns:
str: 事件名。
"""
return f"{self.type}.{self.detail_type}"
@property
def user_id(self) -> Optional[str]:
"""
用户 ID例如消息的发送者或事件的触发者。
Returns:
Optional[str]: 用户 ID。
"""
return self.get("FromUserName")
@property
def agent_id(self) -> Optional[int]:
"""
机器人 ID仅在消息类型事件中存在。
Returns:
Optional[int]: 机器人 ID。
"""
return self.get("AgentID")
@property
def receiver_id(self) -> Optional[str]:
"""
接收者 ID例如机器人自身的企业微信 ID。
Returns:
Optional[str]: 接收者 ID。
"""
return self.get("ToUserName")
@property
def message_id(self) -> Optional[str]:
"""
消息 ID仅在消息类型事件中存在。
Returns:
Optional[str]: 消息 ID。
"""
return self.get("MsgId")
@property
def message(self) -> Optional[str]:
"""
消息内容,仅在消息类型事件中存在。
Returns:
Optional[str]: 消息内容。
"""
return self.get("Content")
@property
def media_id(self) -> Optional[str]:
"""
媒体文件 ID仅在图片、语音等消息类型中存在。
Returns:
Optional[str]: 媒体文件 ID。
"""
return self.get("MediaId")
@property
def timestamp(self) -> Optional[int]:
"""
事件发生的时间戳。
Returns:
Optional[int]: 时间戳。
"""
return self.get("CreateTime")
@property
def event_key(self) -> Optional[str]:
"""
事件的 Key 值,例如点击菜单时的 `EventKey`。
Returns:
Optional[str]: 事件 Key。
"""
return self.get("EventKey")
def __getattr__(self, key: str) -> Optional[Any]:
"""
允许通过属性访问数据中的任意字段。
Args:
key (str): 字段名。
Returns:
Optional[Any]: 字段值。
"""
return self.get(key)
def __setattr__(self, key: str, value: Any) -> None:
"""
允许通过属性设置数据中的任意字段。
Args:
key (str): 字段名。
value (Any): 字段值。
"""
self[key] = value
def __repr__(self) -> str:
"""
生成事件对象的字符串表示。
Returns:
str: 字符串表示。
"""
return f"<WecomEvent {super().__repr__()}>"

View File

@@ -44,10 +44,10 @@ class Query(pydantic.BaseModel):
launcher_type: LauncherTypes
"""会话类型platform处理阶段设置"""
launcher_id: int
launcher_id: typing.Union[int, str]
"""会话IDplatform处理阶段设置"""
sender_id: int
sender_id: typing.Union[int, str]
"""发送者IDplatform处理阶段设置"""
message_event: platform_events.MessageEvent
@@ -113,9 +113,9 @@ class Session(pydantic.BaseModel):
"""会话,一个 Session 对应一个 {launcher_type.value}_{launcher_id}"""
launcher_type: LauncherTypes
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: typing.Optional[int] = 0
sender_id: typing.Optional[typing.Union[int, str]] = 0
use_prompt_name: typing.Optional[str] = 'default'

View File

@@ -68,7 +68,7 @@ class Controller:
except Exception as e:
# traceback.print_exc()
self.ap.logger.error(f"控制器循环出错: {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
self.ap.logger.error(f"Traceback: {traceback.format_exc()}")
async def _check_output(self, query: entities.Query, result: pipeline_entities.StageProcessResult):
"""检查输出
@@ -163,29 +163,30 @@ class Controller:
async def process_query(self, query: entities.Query):
"""处理请求
"""
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
try:
# ======== 触发 MessageReceived 事件 ========
event_type = events.PersonMessageReceived if query.launcher_type == entities.LauncherTypes.PERSON else events.GroupMessageReceived
event_ctx = await self.ap.plugin_mgr.emit_event(
event=event_type(
launcher_type=query.launcher_type.value,
launcher_id=query.launcher_id,
sender_id=query.sender_id,
message_chain=query.message_chain,
query=query
)
)
if event_ctx.is_prevented_default():
return
self.ap.logger.debug(f"Processing query {query}")
await self._execute_from_stage(0, query)
except Exception as e:
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={query.current_stage.inst_name} : {e}")
inst_name = query.current_stage.inst_name if query.current_stage else 'unknown'
self.ap.logger.error(f"处理请求时出错 query_id={query.query_id} stage={inst_name} : {e}")
self.ap.logger.debug(f"Traceback: {traceback.format_exc()}")
finally:
self.ap.logger.debug(f"Query {query} processed")

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
import asyncio
import typing
from ..core import entities
from ..platform import adapter as msadapter
@@ -29,8 +29,8 @@ class QueryPool:
async def add_query(
self,
launcher_type: entities.LauncherTypes,
launcher_id: int,
sender_id: int,
launcher_id: typing.Union[int, str],
sender_id: typing.Union[int, str],
message_event: platform_events.MessageEvent,
message_chain: platform_message.MessageChain,
adapter: msadapter.MessageSourceAdapter

View File

@@ -31,7 +31,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
pass
@abc.abstractmethod
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
"""进入处理流程
这个方法对等待是友好的,意味着算法可以实现在这里等待一段时间以控制速率。
@@ -46,7 +46,7 @@ class ReteLimitAlgo(metaclass=abc.ABCMeta):
raise NotImplementedError
@abc.abstractmethod
async def release_access(self, launcher_type: str, launcher_id: int):
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
"""退出处理流程
Args:

View File

@@ -1,6 +1,7 @@
from __future__ import annotations
import asyncio
import time
import typing
from .. import algo
# 固定窗口算法
@@ -29,7 +30,7 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
self.containers_lock = asyncio.Lock()
self.containers = {}
async def require_access(self, launcher_type: str, launcher_id: int) -> bool:
async def require_access(self, launcher_type: str, launcher_id: typing.Union[int, str]) -> bool:
# 加锁,找容器
container: SessionContainer = None
@@ -83,5 +84,5 @@ class FixedWindowAlgo(algo.ReteLimitAlgo):
# 返回True
return True
async def release_access(self, launcher_type: str, launcher_id: int):
async def release_access(self, launcher_type: str, launcher_id: typing.Union[int, str]):
pass

View File

@@ -52,7 +52,7 @@ class MessageSourceAdapter(metaclass=abc.ABCMeta):
self.config = config
self.ap = ap
async def send_message(
async def send_message(
self,
target_type: str,
target_id: str,

View File

@@ -37,7 +37,7 @@ class PlatformManager:
async def initialize(self):
from .sources import nakuru, aiocqhttp, qqbotpy
from .sources import nakuru, aiocqhttp, qqbotpy,wecom
async def on_friend_message(event: platform_events.FriendMessage, adapter: msadapter.MessageSourceAdapter):

View File

@@ -0,0 +1,193 @@
from __future__ import annotations
import typing
import asyncio
import traceback
import time
import datetime
import aiocqhttp
import aiohttp
from libs.wecom_api.api import WecomClient
from pkg.platform.adapter import MessageSourceAdapter
from pkg.platform.types import events as platform_events, message as platform_message
from libs.wecom_api.wecomevent import WecomEvent
from pkg.core import app
from .. import adapter
from ...pipeline.longtext.strategies import forward
from ...core import app
from ..types import message as platform_message
from ..types import events as platform_events
from ..types import entities as platform_entities
from ...command.errors import ParamNotEnoughError
class WecomMessageConverter(adapter.MessageConverter):
@staticmethod
async def yiri2target(message_chain:platform_message.MessageChain):
content=''
for msg in message_chain:
if type(msg) is platform_message.Plain:
content+=msg.text
return content
@staticmethod
async def target2yiri(message:str,message_id:int = -1):
yiri_msg_list = []
yiri_msg_list.append(
platform_message.Source(id = message_id,time = datetime.datetime.now())
)
yiri_msg_list.append(platform_message.Plain(text=message))
chain = platform_message.MessageChain(yiri_msg_list)
return chain
class WecomEventConverter:
@staticmethod
async def yiri2target(event:platform_events.Event,bot_account_id:int) -> WecomEvent:
content = await WecomMessageConverter.yiri2target(event.message_chain)
if type(event) is platform_events.GroupMessage:
pass
if type(event) is platform_events.FriendMessage:
payload = {
"MsgType": "text",
"Content": content,
"FromUserName": event.sender.id,
"ToUserName": bot_account_id,
"CreateTime": int(datetime.datetime.now().timestamp()),
"AgentID": event.sender.nickname
}
wecom_event = WecomEvent.from_payload(payload=payload)
if not wecom_event:
raise ValueError("无法从 message_data 构造 WecomEvent 对象")
return wecom_event
@staticmethod
async def target2yiri(event: WecomEvent):
"""
将 WecomEvent 转换为平台的 FriendMessage 对象。
Args:
event (WecomEvent): 企业微信事件。
Returns:
platform_events.FriendMessage: 转换后的 FriendMessage 对象。
"""
# 转换消息链
yiri_chain = await WecomMessageConverter.target2yiri(
event.message, event.message_id
)
# 判断消息类型并进行转换
# if event.message_type == "private": 默认消息都是从好友发出
friend = platform_entities.Friend(
id=event.user_id,
nickname=str(event.agent_id),
remark="",
)
return platform_events.FriendMessage(
sender=friend,
message_chain=yiri_chain,
time=event.timestamp
)
@adapter.adapter_class("wecom")
class WecomeAdapter(adapter.MessageSourceAdapter):
bot:WecomClient
ap:app.Application
bot_account_id:str
message_converter:WecomMessageConverter = WecomMessageConverter()
event_converter:WecomEventConverter = WecomEventConverter()
config:dict
ap:app.Application
def __init__(self, config: dict, ap:app.Application):
self.config = config
#这里需要对config里的内容换成企业微信的config。是config:corpid,token......
self.ap = ap
required_keys = ["corpid","secret","token","EncodingAESKey","contacts_secret"]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise ParamNotEnoughError("企业微信缺少相关配置项,请查看文档或联系管理员")
self.bot = WecomClient(
corpid=config['corpid'],
secret=config['secret'],
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
contacts_secret=config['contacts_secret']
)
async def reply_message(self,message_source:platform_events.MessageEvent,message:platform_message.MessageChain,
quote_origin:bool=False,
):
Wecom_event = await WecomEventConverter.yiri2target(message_source,self.bot_account_id)
Wecom_msg = await WecomMessageConverter.yiri2target(message)
# message_converter传回一个消息str
user_id = Wecom_event.user_id
agent_id = Wecom_event.agent_id
return await self.bot.send_private_msg(user_id=user_id,agent_id=agent_id,content=Wecom_msg)
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def register_listener(
self,
event_type:typing.Type[platform_events.Event],
callback:typing.Callable[[platform_events.Event,adapter.MessageSourceAdapter],None],
):
async def on_message(event:WecomEvent):
self.bot_account_id = event.receiver_id
try:
return await callback(await self.event_converter.target2yiri(event),self)
except:
traceback.print_exc()
if event_type == platform_events.FriendMessage:
self.bot.on_message("text")(on_message)
elif event_type == platform_events.GroupMessage:
pass
async def run_async(self):
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await self.bot.run_task(host=self.config['host'],port=self.config['port'],shutdown_trigger=shutdown_trigger_placeholder)
async def kill(self) -> bool:
return False
async def unregister_listener(self, event_type: type, callback: typing.Callable[[platform_events.Event, MessageSourceAdapter], None]):
return super().unregister_listener(event_type, callback)

View File

@@ -25,7 +25,7 @@ class Entity(pydantic.BaseModel):
class Friend(Entity):
"""好友。"""
id: int
id: typing.Union[int, str]
"""QQ 号。"""
nickname: typing.Optional[str]
"""昵称。"""
@@ -52,7 +52,7 @@ class Permission(str, Enum):
class Group(Entity):
"""群。"""
id: int
id: typing.Union[int, str]
"""群号。"""
name: str
"""群名称。"""
@@ -67,7 +67,7 @@ class Group(Entity):
class GroupMember(Entity):
"""群成员。"""
id: int
id: typing.Union[int, str]
"""QQ 号。"""
member_name: str
"""群成员名称。"""
@@ -92,7 +92,7 @@ class GroupMember(Entity):
class Client(Entity):
"""来自其他客户端的用户。"""
id: int
id: typing.Union[int, str]
"""识别 id。"""
platform: str
"""来源平台。"""
@@ -105,7 +105,7 @@ class Client(Entity):
class Subject(pydantic.BaseModel):
"""另一种实体类型表示。"""
id: int
id: typing.Union[int, str]
"""QQ 号或群号。"""
kind: typing.Literal['Friend', 'Group', 'Stranger']
"""类型。"""

View File

@@ -485,11 +485,11 @@ class Quote(MessageComponent):
"""消息组件类型。"""
id: typing.Optional[int] = None
"""被引用回复的原消息的 message_id。"""
group_id: typing.Optional[int] = None
group_id: typing.Optional[typing.Union[int, str]] = None
"""被引用回复的原消息所接收的群号当为好友消息时为0。"""
sender_id: typing.Optional[int] = None
sender_id: typing.Optional[typing.Union[int, str]] = None
"""被引用回复的原消息的发送者的QQ号。"""
target_id: typing.Optional[int] = None
target_id: typing.Optional[typing.Union[int, str]] = None
"""被引用回复的原消息的接收者者的QQ号或群号"""
origin: MessageChain
"""被引用回复的原消息的消息链对象。"""
@@ -749,7 +749,7 @@ class Voice(MessageComponent):
class ForwardMessageNode(pydantic.BaseModel):
"""合并转发中的一条消息。"""
sender_id: typing.Optional[int] = None
sender_id: typing.Optional[typing.Union[int, str]] = None
"""发送人QQ号。"""
sender_name: typing.Optional[str] = None
"""显示名称。"""

View File

@@ -25,10 +25,10 @@ class PersonMessageReceived(BaseEventModel):
launcher_type: str
"""发起对象类型(group/person)"""
launcher_id: int
launcher_id: typing.Union[int, str]
"""发起对象ID(群号/QQ号)"""
sender_id: int
sender_id: typing.Union[int, str]
"""发送者ID(QQ号)"""
message_chain: platform_message.MessageChain
@@ -39,9 +39,9 @@ class GroupMessageReceived(BaseEventModel):
launcher_type: str
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: int
sender_id: typing.Union[int, str]
message_chain: platform_message.MessageChain
@@ -51,9 +51,9 @@ class PersonNormalMessageReceived(BaseEventModel):
launcher_type: str
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: int
sender_id: typing.Union[int, str]
text_message: str
@@ -69,9 +69,9 @@ class PersonCommandSent(BaseEventModel):
launcher_type: str
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: int
sender_id: typing.Union[int, str]
command: str
@@ -93,9 +93,9 @@ class GroupNormalMessageReceived(BaseEventModel):
launcher_type: str
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: int
sender_id: typing.Union[int, str]
text_message: str
@@ -111,9 +111,9 @@ class GroupCommandSent(BaseEventModel):
launcher_type: str
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: int
sender_id: typing.Union[int, str]
command: str
@@ -135,9 +135,9 @@ class NormalMessageResponded(BaseEventModel):
launcher_type: str
launcher_id: int
launcher_id: typing.Union[int, str]
sender_id: int
sender_id: typing.Union[int, str]
session: core_entities.Session
"""会话对象"""

View File

@@ -24,6 +24,7 @@ aiofiles
aioshutil
argon2-cffi
pyjwt
pycryptodome
# indirect
taskgroup==0.0.0a4

View File

@@ -24,6 +24,17 @@
"public_guild_messages",
"direct_message"
]
},
{
"adapter":"wecom",
"enable":false,
"host":"0.0.0.0",
"port":5001,
"corpid":"",
"secret":"",
"token":"",
"EncodingAESKey":"",
"contacts_secret":""
}
],
"track-function-calls": true,