style: introduce ruff as linter and formatter (#1356)

* style: remove necessary imports

* style: fix F841

* style: fix F401

* style: fix F811

* style: fix E402

* style: fix E721

* style: fix E722

* style: fix E722

* style: fix F541

* style: ruff format

* style: all passed

* style: add ruff in deps

* style: more ignores in ruff.toml

* style: add pre-commit
This commit is contained in:
Junyan Qin (Chin)
2025-04-29 17:24:07 +08:00
committed by GitHub
parent 09e70d70e9
commit 209f16af76
240 changed files with 5307 additions and 4689 deletions

View File

@@ -1,2 +1,4 @@
from .v1 import client
from .v1 import errors
from .v1 import client as client
from .v1 import errors as errors
__all__ = ['client', 'errors']

View File

@@ -8,25 +8,33 @@ import json
class TestDifyClient:
async def test_chat_messages(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"):
async for chunk in cln.chat_messages(
inputs={}, query='调用工具查看现在几点?', user='test'
):
print(json.dumps(chunk, ensure_ascii=False, indent=4))
async def test_upload_file(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
file_bytes = open("img.png", "rb").read()
file_bytes = open('img.png', 'rb').read()
print(type(file_bytes))
file = ("img2.png", file_bytes, "image/png")
file = ('img2.png', file_bytes, 'image/png')
resp = await cln.upload_file(file=file, user="test")
resp = await cln.upload_file(file=file, user='test')
print(json.dumps(resp, ensure_ascii=False, indent=4))
async def test_workflow_run(self):
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
cln = client.AsyncDifyServiceClient(
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
)
# resp = await cln.workflow_run(inputs={}, user="test")
# # print(json.dumps(resp, ensure_ascii=False, indent=4))
@@ -34,11 +42,12 @@ class TestDifyClient:
chunks = []
ignored_events = ['text_chunk']
async for chunk in cln.workflow_run(inputs={}, user="test"):
async for chunk in cln.workflow_run(inputs={}, user='test'):
if chunk['event'] in ignored_events:
continue
chunks.append(chunk)
print(json.dumps(chunks, ensure_ascii=False, indent=4))
if __name__ == "__main__":
if __name__ == '__main__':
asyncio.run(TestDifyClient().test_chat_messages())

View File

@@ -12,11 +12,11 @@ class AsyncDifyServiceClient:
api_key: str
base_url: str
def __init__(
self,
api_key: str,
base_url: str = "https://api.dify.ai/v1",
base_url: str = 'https://api.dify.ai/v1',
) -> None:
self.api_key = api_key
self.base_url = base_url
@@ -26,76 +26,81 @@ class AsyncDifyServiceClient:
inputs: dict[str, typing.Any],
query: str,
user: str,
response_mode: str = "streaming", # 当前不支持 blocking
conversation_id: str = "",
response_mode: str = 'streaming', # 当前不支持 blocking
conversation_id: str = '',
files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""发送消息"""
if response_mode != "streaming":
raise DifyAPIError("当前仅支持 streaming 模式")
if response_mode != 'streaming':
raise DifyAPIError('当前仅支持 streaming 模式')
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
async with client.stream(
"POST",
"/chat-messages",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
'POST',
'/chat-messages',
headers={
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
},
json={
"inputs": inputs,
"query": query,
"user": user,
"response_mode": response_mode,
"conversation_id": conversation_id,
"files": files,
'inputs': inputs,
'query': query,
'user': user,
'response_mode': response_mode,
'conversation_id': conversation_id,
'files': files,
},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith("data:"):
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def workflow_run(
self,
inputs: dict[str, typing.Any],
user: str,
response_mode: str = "streaming", # 当前不支持 blocking
response_mode: str = 'streaming', # 当前不支持 blocking
files: list[dict[str, typing.Any]] = [],
timeout: float = 30.0,
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
"""运行工作流"""
if response_mode != "streaming":
raise DifyAPIError("当前仅支持 streaming 模式")
if response_mode != 'streaming':
raise DifyAPIError('当前仅支持 streaming 模式')
async with httpx.AsyncClient(
base_url=self.base_url,
trust_env=True,
timeout=timeout,
) as client:
async with client.stream(
"POST",
"/workflows/run",
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
'POST',
'/workflows/run',
headers={
'Authorization': f'Bearer {self.api_key}',
'Content-Type': 'application/json',
},
json={
"inputs": inputs,
"user": user,
"response_mode": response_mode,
"files": files,
'inputs': inputs,
'user': user,
'response_mode': response_mode,
'files': files,
},
) as r:
async for chunk in r.aiter_lines():
if r.status_code != 200:
raise DifyAPIError(f"{r.status_code} {chunk}")
if chunk.strip() == "":
raise DifyAPIError(f'{r.status_code} {chunk}')
if chunk.strip() == '':
continue
if chunk.startswith("data:"):
if chunk.startswith('data:'):
yield json.loads(chunk[5:])
async def upload_file(
@@ -112,15 +117,15 @@ class AsyncDifyServiceClient:
) as client:
# multipart/form-data
response = await client.post(
"/files/upload",
headers={"Authorization": f"Bearer {self.api_key}"},
'/files/upload',
headers={'Authorization': f'Bearer {self.api_key}'},
files={
"file": file,
"user": (None, user),
'file': file,
'user': (None, user),
},
)
if response.status_code != 201:
raise DifyAPIError(f"{response.status_code} {response.text}")
raise DifyAPIError(f'{response.status_code} {response.text}')
return response.json()

View File

@@ -7,11 +7,11 @@ import os
class TestDifyClient:
async def test_chat_messages(self):
cln = client.DifyClient(api_key=os.getenv("DIFY_API_KEY"))
cln = client.DifyClient(api_key=os.getenv('DIFY_API_KEY'))
resp = await cln.chat_messages(inputs={}, query="Who are you?", user_id="test")
resp = await cln.chat_messages(inputs={}, query='Who are you?', user_id='test')
print(resp)
if __name__ == "__main__":
if __name__ == '__main__':
asyncio.run(TestDifyClient().test_chat_messages())

View File

@@ -1,8 +1,8 @@
import asyncio
import json
import dingtalk_stream
from dingtalk_stream import AckMessage
class EchoTextHandler(dingtalk_stream.ChatbotHandler):
def __init__(self, client):
self.msg_id = ''
@@ -10,6 +10,7 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
self.client = client # 用于更新 DingTalkClient 中的 incoming_message
"""处理钉钉消息"""
async def process(self, callback: dingtalk_stream.CallbackMessage):
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
if incoming_message.message_id != self.msg_id:
@@ -26,6 +27,8 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
return self.incoming_message
async def get_dingtalk_client(client_id, client_secret):
from api import DingTalkClient # 延迟导入,避免循环导入
return DingTalkClient(client_id, client_secret)

View File

@@ -10,7 +10,9 @@ import traceback
class DingTalkClient:
def __init__(self, client_id: str, client_secret: str,robot_name:str,robot_code:str):
def __init__(
self, client_id: str, client_secret: str, robot_name: str, robot_code: str
):
"""初始化 WebSocket 连接并自动启动"""
self.credential = dingtalk_stream.Credential(client_id, client_secret)
self.client = dingtalk_stream.DingTalkStreamClient(self.credential)
@@ -18,106 +20,91 @@ class DingTalkClient:
self.secret = client_secret
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
self.EchoTextHandler = EchoTextHandler(self)
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
self.client.register_callback_handler(
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
)
self._message_handlers = {
"example":[],
'example': [],
}
self.access_token = ''
self.robot_name = robot_name
self.robot_code = robot_code
self.access_token_expiry_time = ''
async def get_access_token(self):
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
headers = {
"Content-Type": "application/json"
}
data = {
"appKey": self.key,
"appSecret": self.secret
}
url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
headers = {'Content-Type': 'application/json'}
data = {'appKey': self.key, 'appSecret': self.secret}
async with httpx.AsyncClient() as client:
try:
response = await client.post(url,json=data,headers=headers)
response = await client.post(url, json=data, headers=headers)
if response.status_code == 200:
response_data = response.json()
self.access_token = response_data.get("accessToken")
expires_in = int(response_data.get("expireIn",7200))
self.access_token = response_data.get('accessToken')
expires_in = int(response_data.get('expireIn', 7200))
self.access_token_expiry_time = time.time() + expires_in - 60
except Exception as e:
raise Exception(e)
async def is_token_expired(self):
"""检查token是否过期"""
if self.access_token_expiry_time is None:
return True
return time.time() > self.access_token_expiry_time
async def check_access_token(self):
if not self.access_token or await self.is_token_expired():
return False
return bool(self.access_token and self.access_token.strip())
async def download_image(self,download_code:str):
async def download_image(self, download_code: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = {
"downloadCode":download_code,
"robotCode":self.robot_code
}
headers ={
"x-acs-dingtalk-access-token": self.access_token
}
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
headers = {'x-acs-dingtalk-access-token': self.access_token}
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
result = response.json()
download_url = result.get("downloadUrl")
download_url = result.get('downloadUrl')
else:
raise Exception(f"Error: {response.status_code}, {response.text}")
raise Exception(f'Error: {response.status_code}, {response.text}')
if download_url:
return await self.download_url_to_base64(download_url)
async def download_url_to_base64(self,download_url):
async def download_url_to_base64(self, download_url):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
if response.status_code == 200:
file_bytes = response.content
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
base64_str = base64.b64encode(file_bytes).decode(
'utf-8'
) # 返回字符串格式
return base64_str
else:
raise Exception("获取文件失败")
async def get_audio_url(self,download_code:str):
raise Exception('获取文件失败')
async def get_audio_url(self, download_code: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
params = {
"downloadCode":download_code,
"robotCode":self.robot_code
}
headers ={
"x-acs-dingtalk-access-token": self.access_token
}
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
headers = {'x-acs-dingtalk-access-token': self.access_token}
async with httpx.AsyncClient() as client:
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
result = response.json()
download_url = result.get("downloadUrl")
download_url = result.get('downloadUrl')
if download_url:
return await self.download_url_to_base64(download_url)
else:
raise Exception("获取音频失败")
raise Exception('获取音频失败')
else:
raise Exception(f"Error: {response.status_code}, {response.text}")
raise Exception(f'Error: {response.status_code}, {response.text}')
async def update_incoming_message(self, message):
"""异步更新 DingTalkClient 中的 incoming_message"""
message_data = await self.get_message(message)
@@ -125,24 +112,21 @@ class DingTalkClient:
event = DingTalkEvent.from_payload(message_data)
if event:
await self._handle_message(event)
async def send_message(self,content:str,incoming_message):
self.EchoTextHandler.reply_text(content,incoming_message)
async def send_message(self, content: str, incoming_message):
self.EchoTextHandler.reply_text(content, incoming_message)
async def get_incoming_message(self):
"""获取收到的消息"""
return await self.EchoTextHandler.get_incoming_message()
def on_message(self, msg_type: str):
def decorator(func: Callable[[DingTalkEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: DingTalkEvent):
@@ -154,40 +138,44 @@ class DingTalkClient:
for handler in self._message_handlers[msg_type]:
await handler(event)
async def get_message(self,incoming_message:dingtalk_stream.chatbot.ChatbotMessage):
async def get_message(
self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
):
try:
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
message_data = {
"IncomingMessage":incoming_message,
'IncomingMessage': incoming_message,
}
if str(incoming_message.conversation_type) == '1':
message_data["conversation_type"] = 'FriendMessage'
message_data['conversation_type'] = 'FriendMessage'
elif str(incoming_message.conversation_type) == '2':
message_data["conversation_type"] = 'GroupMessage'
message_data['conversation_type'] = 'GroupMessage'
if incoming_message.message_type == 'richText':
data = incoming_message.rich_text_content.to_dict()
for item in data['richText']:
if 'text' in item:
message_data["Content"] = item['text']
message_data['Content'] = item['text']
if incoming_message.get_image_list()[0]:
message_data["Picture"] = await self.download_image(incoming_message.get_image_list()[0])
message_data["Type"] = 'text'
message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Type'] = 'text'
elif incoming_message.message_type == 'text':
message_data['Content'] = incoming_message.get_text_list()[0]
message_data["Type"] = 'text'
message_data['Type'] = 'text'
elif incoming_message.message_type == 'picture':
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
message_data['Picture'] = await self.download_image(
incoming_message.get_image_list()[0]
)
message_data['Type'] = 'image'
elif incoming_message.message_type == 'audio':
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
message_data['Audio'] = await self.get_audio_url(
incoming_message.to_dict()['content']['downloadCode']
)
message_data['Type'] = 'audio'
@@ -196,56 +184,55 @@ class DingTalkClient:
# print("message_data:", json.dumps(copy_message_data, indent=4, ensure_ascii=False))
except Exception:
traceback.print_exc()
return message_data
async def send_proactive_message_to_one(self,target_id:str,content:str):
async def send_proactive_message_to_one(self, target_id: str, content: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend'
headers ={
"x-acs-dingtalk-access-token":self.access_token,
"Content-Type":"application/json",
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
data ={
"robotCode":self.robot_code,
"userIds":[target_id],
"msgKey": "sampleText",
"msgParam": json.dumps({"content":content}),
data = {
'robotCode': self.robot_code,
'userIds': [target_id],
'msgKey': 'sampleText',
'msgParam': json.dumps({'content': content}),
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(url,headers=headers,json=data)
await client.post(url, headers=headers, json=data)
except Exception:
traceback.print_exc()
async def send_proactive_message_to_group(self,target_id:str,content:str):
async def send_proactive_message_to_group(self, target_id: str, content: str):
if not await self.check_access_token():
await self.get_access_token()
url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send'
headers ={
"x-acs-dingtalk-access-token":self.access_token,
"Content-Type":"application/json",
headers = {
'x-acs-dingtalk-access-token': self.access_token,
'Content-Type': 'application/json',
}
data ={
"robotCode":self.robot_code,
"openConversationId":target_id,
"msgKey": "sampleText",
"msgParam": json.dumps({"content":content}),
data = {
'robotCode': self.robot_code,
'openConversationId': target_id,
'msgKey': 'sampleText',
'msgParam': json.dumps({'content': content}),
}
try:
async with httpx.AsyncClient() as client:
response = await client.post(url,headers=headers,json=data)
await client.post(url, headers=headers, json=data)
except Exception:
traceback.print_exc()
async def start(self):
"""启动 WebSocket 连接,监听消息"""
await self.client.start()
await self.client.start()

View File

@@ -1,41 +1,39 @@
from typing import Dict, Any, Optional
import dingtalk_stream
class DingTalkEvent(dict):
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["DingTalkEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['DingTalkEvent']:
try:
event = DingTalkEvent(payload)
return event
except KeyError:
return None
@property
def content(self):
return self.get("Content","")
@property
def incoming_message(self) -> Optional["dingtalk_stream.chatbot.ChatbotMessage"]:
return self.get("IncomingMessage")
def content(self):
return self.get('Content', '')
@property
def incoming_message(self) -> Optional['dingtalk_stream.chatbot.ChatbotMessage']:
return self.get('IncomingMessage')
@property
def type(self):
return self.get("Type","")
return self.get('Type', '')
@property
def picture(self):
return self.get("Picture","")
return self.get('Picture', '')
@property
def audio(self):
return self.get("Audio","")
return self.get('Audio', '')
@property
def conversation(self):
return self.get("conversation_type","")
return self.get('conversation_type', '')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -66,4 +64,4 @@ class DingTalkEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<DingTalkEvent {super().__repr__()}>"
return f'<DingTalkEvent {super().__repr__()}>'

View File

@@ -1,20 +1,14 @@
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
from collections import deque
import time
import traceback
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
import xml.etree.ElementTree as ET
from quart import Quart,request
from quart import Quart, request
import hashlib
from typing import Callable, Dict, Any
from typing import Callable
from .oaevent import OAEvent
import httpx
import asyncio
import time
import xml.etree.ElementTree as ET
from pkg.platform.sources import officialaccount as oa
xml_template = """
@@ -28,9 +22,8 @@ xml_template = """
"""
class OAClient():
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str):
class OAClient:
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str):
self.token = token
self.aes = EncodingAESKey
self.appid = AppID
@@ -38,121 +31,130 @@ class OAClient():
self.base_url = 'https://api.weixin.qq.com'
self.access_token = ''
self.app = Quart(__name__)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
"example":[],
'example': [],
}
self.access_token_expiry_time = None
self.msg_id_map = {}
self.generated_content = {}
async def handle_callback_request(self):
try:
# 每隔100毫秒查询是否生成ai回答
start_time = time.time()
signature = request.args.get("signature", "")
timestamp = request.args.get("timestamp", "")
nonce = request.args.get("nonce", "")
echostr = request.args.get("echostr", "")
msg_signature = request.args.get("msg_signature","")
signature = request.args.get('signature', '')
timestamp = request.args.get('timestamp', '')
nonce = request.args.get('nonce', '')
echostr = request.args.get('echostr', '')
msg_signature = request.args.get('msg_signature', '')
if msg_signature is None:
raise Exception("msg_signature不在请求体中")
raise Exception('msg_signature不在请求体中')
if request.method == 'GET':
# 校验签名
check_str = "".join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
if check_signature == signature:
return echostr # 验证成功返回echostr
else:
raise Exception("拒绝请求")
elif request.method == "POST":
raise Exception('拒绝请求')
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid)
ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce)
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
xml_msg = xml_msg.decode('utf-8')
if ret != 0:
raise Exception("消息解密失败")
raise Exception('消息解密失败')
message_data = await self.get_message(xml_msg)
if message_data :
if message_data:
event = OAEvent.from_payload(message_data)
if event:
await self._handle_message(event)
root = ET.fromstring(xml_msg)
from_user = root.find("FromUserName").text # 发送者
to_user = root.find("ToUserName").text # 机器人
from_user = root.find('FromUserName').text # 发送者
to_user = root.find('ToUserName').text # 机器人
timeout = 4.80
interval = 0.1
while True:
content = self.generated_content.pop(message_data["MsgId"], None)
content = self.generated_content.pop(message_data['MsgId'], None)
if content:
response_xml = xml_template.format(
to_user=from_user,
from_user=to_user,
create_time=int(time.time()),
content = content
content=content,
)
return response_xml
if time.time() - start_time >= timeout:
break
await asyncio.sleep(interval)
if self.msg_id_map.get(message_data["MsgId"], 1) == 3:
if self.msg_id_map.get(message_data['MsgId'], 1) == 3:
# response_xml = xml_template.format(
# to_user=from_user,
# from_user=to_user,
# create_time=int(time.time()),
# content = "请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。"
# )
print("请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。")
print(
'请求失效暂不支持公众号超过15秒的请求如有需求请联系 LangBot 团队。'
)
return ''
except Exception as e:
except Exception:
traceback.print_exc()
async def get_message(self, xml_msg: str):
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
'ToUserName': root.find('ToUserName').text,
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
}
return message_data
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str):
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[OAEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: OAEvent):
@@ -170,14 +172,19 @@ class OAClient():
for handler in self._message_handlers[msg_type]:
await handler(event)
async def set_message(self,msg_id:int,content:str):
async def set_message(self, msg_id: int, content: str):
self.generated_content[msg_id] = content
class OAClientForLongerResponse():
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str,LoadingMessage:str):
class OAClientForLongerResponse:
def __init__(
self,
token: str,
EncodingAESKey: str,
AppID: str,
Appsecret: str,
LoadingMessage: str,
):
self.token = token
self.aes = EncodingAESKey
self.appid = AppID
@@ -185,9 +192,14 @@ class OAClientForLongerResponse():
self.base_url = 'https://api.weixin.qq.com'
self.access_token = ''
self.app = Quart(__name__)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
"example":[],
'example': [],
}
self.access_token_expiry_time = None
self.loading_message = LoadingMessage
@@ -196,50 +208,55 @@ class OAClientForLongerResponse():
async def handle_callback_request(self):
try:
start_time = time.time()
signature = request.args.get("signature", "")
timestamp = request.args.get("timestamp", "")
nonce = request.args.get("nonce", "")
echostr = request.args.get("echostr", "")
msg_signature = request.args.get("msg_signature", "")
signature = request.args.get('signature', '')
timestamp = request.args.get('timestamp', '')
nonce = request.args.get('nonce', '')
echostr = request.args.get('echostr', '')
msg_signature = request.args.get('msg_signature', '')
if msg_signature is None:
raise Exception("msg_signature不在请求体中")
raise Exception('msg_signature不在请求体中')
if request.method == 'GET':
check_str = "".join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
return echostr if check_signature == signature else "拒绝请求"
check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
return echostr if check_signature == signature else '拒绝请求'
elif request.method == "POST":
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
ret, xml_msg = wxcpt.DecryptMsg(
encryt_msg, msg_signature, timestamp, nonce
)
xml_msg = xml_msg.decode('utf-8')
if ret != 0:
raise Exception("消息解密失败")
raise Exception('消息解密失败')
# 解析 XML
root = ET.fromstring(xml_msg)
from_user = root.find("FromUserName").text
to_user = root.find("ToUserName").text
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]["content"]:
from_user = root.find('FromUserName').text
to_user = root.find('ToUserName').text
if (
self.msg_queue.get(from_user)
and self.msg_queue[from_user][0]['content']
):
queue_top = self.msg_queue[from_user].pop(0)
queue_content = queue_top["content"]
queue_content = queue_top['content']
# 弹出用户消息
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]:
if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user]
):
self.user_msg_queue[from_user].pop(0)
response_xml = xml_template.format(
to_user=from_user,
from_user=to_user,
create_time=int(time.time()),
content=queue_content
content=queue_content,
)
return response_xml
@@ -248,65 +265,67 @@ class OAClientForLongerResponse():
to_user=from_user,
from_user=to_user,
create_time=int(time.time()),
content=self.loading_message
content=self.loading_message,
)
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]["content"]:
if (
self.user_msg_queue.get(from_user)
and self.user_msg_queue[from_user][0]['content']
):
return response_xml
else:
message_data = await self.get_message(xml_msg)
if message_data:
event = OAEvent.from_payload(message_data)
if event:
self.user_msg_queue.setdefault(from_user,[]).append(
self.user_msg_queue.setdefault(from_user, []).append(
{
"content":event.message,
'content': event.message,
}
)
await self._handle_message(event)
return response_xml
except Exception as e:
except Exception:
traceback.print_exc()
async def get_message(self, xml_msg: str):
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
'ToUserName': root.find('ToUserName').text,
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
}
return message_data
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)
def on_message(self, msg_type: str):
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[OAEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: OAEvent):
@@ -319,22 +338,13 @@ class OAClientForLongerResponse():
for handler in self._message_handlers[msg_type]:
await handler(event)
async def set_message(self,from_user:int,message_id:int,content:str):
if from_user not in self.msg_queue:
async def set_message(self, from_user: int, message_id: int, content: str):
if from_user not in self.msg_queue:
self.msg_queue[from_user] = []
self.msg_queue[from_user].append(
{
"msg_id":message_id,
"content":content,
'msg_id': message_id,
'content': content,
}
)

View File

@@ -9,7 +9,7 @@ class OAEvent(dict):
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['OAEvent']:
"""
从微信公众号事件数据构造 `WecomEvent` 对象。
@@ -34,14 +34,14 @@ class OAEvent(dict):
Returns:
str: 事件类型。
"""
return self.get("MsgType", "")
return self.get('MsgType', '')
@property
def picurl(self) -> str:
"""
图片链接
"""
return self.get("PicUrl","")
return self.get('PicUrl', '')
@property
def detail_type(self) -> str:
@@ -53,8 +53,8 @@ class OAEvent(dict):
Returns:
str: 事件详细类型。
"""
if self.type == "event":
return self.get("Event", "")
if self.type == 'event':
return self.get('Event', '')
return self.type
@property
@@ -65,15 +65,14 @@ class OAEvent(dict):
Returns:
str: 事件名。
"""
return f"{self.type}.{self.detail_type}"
return f'{self.type}.{self.detail_type}'
@property
def user_id(self) -> Optional[str]:
"""
发送方账号
"""
return self.get("FromUserName")
return self.get('FromUserName')
@property
def receiver_id(self) -> Optional[str]:
@@ -83,7 +82,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 接收者 ID。
"""
return self.get("ToUserName")
return self.get('ToUserName')
@property
def message_id(self) -> Optional[str]:
@@ -93,7 +92,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 消息 ID。
"""
return self.get("MsgId")
return self.get('MsgId')
@property
def message(self) -> Optional[str]:
@@ -103,7 +102,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 消息内容。
"""
return self.get("Content")
return self.get('Content')
@property
def media_id(self) -> Optional[str]:
@@ -113,7 +112,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 媒体文件 ID。
"""
return self.get("MediaId")
return self.get('MediaId')
@property
def timestamp(self) -> Optional[int]:
@@ -123,7 +122,7 @@ class OAEvent(dict):
Returns:
Optional[int]: 时间戳。
"""
return self.get("CreateTime")
return self.get('CreateTime')
@property
def event_key(self) -> Optional[str]:
@@ -133,7 +132,7 @@ class OAEvent(dict):
Returns:
Optional[str]: 事件 Key。
"""
return self.get("EventKey")
return self.get('EventKey')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -164,4 +163,4 @@ class OAEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<WecomEvent {super().__repr__()}>"
return f'<WecomEvent {super().__repr__()}>'

View File

@@ -1,24 +1,16 @@
import time
from quart import request
import base64
import binascii
import httpx
from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from pkg.platform.types import events as platform_events, message as platform_message
import aiofiles
from pkg.platform.types import events as platform_events
from .qqofficialevent import QQOfficialEvent
import json
import hmac
import base64
import hashlib
import traceback
from cryptography.hazmat.primitives.asymmetric import ed25519
from .qqofficialevent import QQOfficialEvent
def handle_validation(body: dict, bot_secret: str):
# bot正确的secert是32位的此处仅为了适配演示demo
while len(bot_secret) < 32:
bot_secret = bot_secret * 2
@@ -36,29 +28,26 @@ def handle_validation(body: dict, bot_secret: str):
signature_hex = signature.hex()
response = {
"plain_token": body['d']['plain_token'],
"signature": signature_hex
}
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
return response
class QQOfficialClient:
def __init__(self, secret: str, token: str, app_id: str):
self.app = Quart(__name__)
self.app.add_url_rule(
"/callback/command",
"handle_callback",
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=["GET", "POST"],
methods=['GET', 'POST'],
)
self.secret = secret
self.token = token
self.app_id = app_id
self._message_handlers = {
}
self.base_url = "https://api.sgroup.qq.com"
self.access_token = ""
self._message_handlers = {}
self.base_url = 'https://api.sgroup.qq.com'
self.access_token = ''
self.access_token_expiry_time = None
async def check_access_token(self):
@@ -66,30 +55,29 @@ class QQOfficialClient:
if not self.access_token or await self.is_token_expired():
return False
return bool(self.access_token and self.access_token.strip())
async def get_access_token(self):
"""获取access_token"""
url = "https://bots.qq.com/app/getAppAccessToken"
url = 'https://bots.qq.com/app/getAppAccessToken'
async with httpx.AsyncClient() as client:
params = {
"appId":self.app_id,
"clientSecret":self.secret,
'appId': self.app_id,
'clientSecret': self.secret,
}
headers = {
"content-type":"application/json",
'content-type': 'application/json',
}
try:
response = await client.post(url,json=params,headers=headers)
response = await client.post(url, json=params, headers=headers)
if response.status_code == 200:
response_data = response.json()
access_token = response_data.get("access_token")
expires_in = int(response_data.get("expires_in",7200))
access_token = response_data.get('access_token')
expires_in = int(response_data.get('expires_in', 7200))
self.access_token_expiry_time = time.time() + expires_in - 60
if access_token:
self.access_token = access_token
except Exception as e:
raise Exception(f"获取access_token失败: {e}")
raise Exception(f'获取access_token失败: {e}')
async def handle_callback_request(self):
"""处理回调请求"""
@@ -98,27 +86,24 @@ class QQOfficialClient:
body = await request.get_data()
payload = json.loads(body)
# 验证是否为回调验证请求
if payload.get("op") == 13:
if payload.get('op') == 13:
# 生成签名
response = handle_validation(payload, self.secret)
return response
if payload.get("op") == 0:
message_data = await self.get_message(payload)
if message_data:
event = QQOfficialEvent.from_payload(message_data)
await self._handle_message(event)
return {"code": 0, "message": "success"}
if payload.get('op') == 0:
message_data = await self.get_message(payload)
if message_data:
event = QQOfficialEvent.from_payload(message_data)
await self._handle_message(event)
return {'code': 0, 'message': 'success'}
except Exception as e:
traceback.print_exc()
return {"error": str(e)}, 400
return {'error': str(e)}, 400
async def run_task(self, host: str, port: int, *args, **kwargs):
"""启动 Quart 应用"""
@@ -135,133 +120,140 @@ class QQOfficialClient:
return decorator
async def _handle_message(self, event:QQOfficialEvent):
async def _handle_message(self, event: QQOfficialEvent):
"""处理消息事件"""
msg_type = event.t
if msg_type in self._message_handlers:
for handler in self._message_handlers[msg_type]:
await handler(event)
async def get_message(self,msg:dict) -> Dict[str,Any]:
async def get_message(self, msg: dict) -> Dict[str, Any]:
"""获取消息"""
message_data = {
"t": msg.get("t",{}),
"user_openid": msg.get("d",{}).get("author",{}).get("user_openid",{}),
"timestamp": msg.get("d",{}).get("timestamp",{}),
"d_author_id": msg.get("d",{}).get("author",{}).get("id",{}),
"content": msg.get("d",{}).get("content",{}),
"d_id": msg.get("d",{}).get("id",{}),
"id": msg.get("id",{}),
"channel_id": msg.get("d",{}).get("channel_id",{}),
"username": msg.get("d",{}).get("author",{}).get("username",{}),
"guild_id": msg.get("d",{}).get("guild_id",{}),
"member_openid": msg.get("d",{}).get("author",{}).get("openid",{}),
"group_openid": msg.get("d",{}).get("group_openid",{})
't': msg.get('t', {}),
'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
'timestamp': msg.get('d', {}).get('timestamp', {}),
'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
'content': msg.get('d', {}).get('content', {}),
'd_id': msg.get('d', {}).get('id', {}),
'id': msg.get('id', {}),
'channel_id': msg.get('d', {}).get('channel_id', {}),
'username': msg.get('d', {}).get('author', {}).get('username', {}),
'guild_id': msg.get('d', {}).get('guild_id', {}),
'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
'group_openid': msg.get('d', {}).get('group_openid', {}),
}
attachments = msg.get("d", {}).get("attachments", [])
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
image_attachments_type = [attachment['content_type'] for attachment in attachments if await self.is_image(attachment)]
attachments = msg.get('d', {}).get('attachments', [])
image_attachments = [
attachment['url']
for attachment in attachments
if await self.is_image(attachment)
]
image_attachments_type = [
attachment['content_type']
for attachment in attachments
if await self.is_image(attachment)
]
if image_attachments:
message_data["image_attachments"] = image_attachments[0]
message_data["content_type"] = image_attachments_type[0]
message_data['image_attachments'] = image_attachments[0]
message_data['content_type'] = image_attachments_type[0]
else:
message_data["image_attachments"] = None
return message_data
message_data['image_attachments'] = None
async def is_image(self,attachment:dict) -> bool:
return message_data
async def is_image(self, attachment: dict) -> bool:
"""判断是否为图片附件"""
content_type = attachment.get("content_type","")
return content_type.startswith("image/")
async def send_private_text_msg(self,user_openid:str,content:str,msg_id:str):
content_type = attachment.get('content_type', '')
return content_type.startswith('image/')
async def send_private_text_msg(self, user_openid: str, content: str, msg_id: str):
"""发送私聊消息"""
if not await self.check_access_token():
await self.get_access_token()
await self.get_access_token()
url = self.base_url + "/v2/users/" + user_openid + "/messages"
url = self.base_url + '/v2/users/' + user_openid + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
data = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=data)
response = await client.post(url, headers=headers, json=data)
if response.status_code == 200:
return
else:
raise ValueError(response)
async def send_group_text_msg(self,group_openid:str,content:str,msg_id:str):
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
"""发送群聊消息"""
if not await self.check_access_token():
await self.get_access_token()
url = self.base_url + "/v2/groups/" + group_openid + "/messages"
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
data = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=data)
response = await client.post(url, headers=headers, json=data)
if response.status_code == 200:
return
else:
raise Exception(response.read().decode())
async def send_channle_group_text_msg(self,channel_id:str,content:str,msg_id:str):
async def send_channle_group_text_msg(
self, channel_id: str, content: str, msg_id: str
):
"""发送频道群聊消息"""
if not await self.check_access_token():
await self.get_access_token()
await self.get_access_token()
url = self.base_url + "/channels/" + channel_id + "/messages"
url = self.base_url + '/channels/' + channel_id + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
params = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=params)
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
return True
else:
raise Exception(response)
async def send_channle_private_text_msg(self,guild_id:str,content:str,msg_id:str):
async def send_channle_private_text_msg(
self, guild_id: str, content: str, msg_id: str
):
"""发送频道私聊消息"""
if not await self.check_access_token():
await self.get_access_token()
await self.get_access_token()
url = self.base_url + "/dms/" + guild_id + "/messages"
url = self.base_url + '/dms/' + guild_id + '/messages'
async with httpx.AsyncClient() as client:
headers = {
"Authorization": f"QQBot {self.access_token}",
"Content-Type": "application/json",
'Authorization': f'QQBot {self.access_token}',
'Content-Type': 'application/json',
}
params = {
"content": content,
"msg_type": 0,
"msg_id": msg_id,
'content': content,
'msg_type': 0,
'msg_id': msg_id,
}
response = await client.post(url,headers=headers,json=params)
response = await client.post(url, headers=headers, json=params)
if response.status_code == 200:
return True
else:

View File

@@ -1,114 +1,112 @@
from typing import Dict, Any, Optional
class QQOfficialEvent(dict):
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["QQOfficialEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['QQOfficialEvent']:
try:
event = QQOfficialEvent(payload)
return event
except KeyError:
return None
@property
def t(self) -> str:
"""
事件类型
"""
return self.get("t", "")
return self.get('t', '')
@property
def user_openid(self) -> str:
"""
用户openid
"""
return self.get("user_openid",{})
return self.get('user_openid', {})
@property
def timestamp(self) -> str:
"""
时间戳
"""
return self.get("timestamp",{})
return self.get('timestamp', {})
@property
def d_author_id(self) -> str:
"""
作者id
"""
return self.get("id",{})
return self.get('id', {})
@property
def content(self) -> str:
"""
内容
"""
return self.get("content",'')
return self.get('content', '')
@property
def d_id(self) -> str:
"""
d_id
"""
return self.get("d_id",{})
return self.get('d_id', {})
@property
def id(self) -> str:
"""
消息idmsg_id
"""
return self.get("id",{})
return self.get('id', {})
@property
def channel_id(self) -> str:
"""
频道id
"""
return self.get("channel_id",{})
return self.get('channel_id', {})
@property
def username(self) -> str:
"""
用户名
"""
return self.get("username",{})
return self.get('username', {})
@property
def guild_id(self) -> str:
"""
频道id
"""
return self.get("guild_id",{})
return self.get('guild_id', {})
@property
def member_openid(self) -> str:
"""
成员openid
"""
return self.get("openid",{})
return self.get('openid', {})
@property
def attachments(self) -> str:
"""
附件url
"""
url = self.get("image_attachments", "")
if url and not url.startswith("https://"):
url = "https://" + url
url = self.get('image_attachments', '')
if url and not url.startswith('https://'):
url = 'https://' + url
return url
@property
def group_openid(self) -> str:
"""
群组id
"""
return self.get("group_openid",{})
return self.get('group_openid', {})
@property
def content_type(self) -> str:
"""
文件类型
"""
return self.get("content_type","")
return self.get('content_type', '')

View File

@@ -1,10 +1,11 @@
#!/usr/bin/env python
# -*- encoding:utf-8 -*-
""" 对企业微信发送给企业后台的消息加解密示例代码.
"""对企业微信发送给企业后台的消息加解密示例代码.
@copyright: Copyright (c) 1998-2014 Tencent Inc.
"""
# ------------------------------------------------------------------------
import logging
import base64
@@ -49,7 +50,7 @@ class SHA1:
sortlist = [token, timestamp, nonce, encrypt]
sortlist.sort()
sha = hashlib.sha1()
sha.update("".join(sortlist).encode())
sha.update(''.join(sortlist).encode())
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
except Exception as e:
logger = logging.getLogger()
@@ -75,7 +76,7 @@ class XMLParse:
"""
try:
xml_tree = ET.fromstring(xmltext)
encrypt = xml_tree.find("Encrypt")
encrypt = xml_tree.find('Encrypt')
return ierror.WXBizMsgCrypt_OK, encrypt.text
except Exception as e:
logger = logging.getLogger()
@@ -100,13 +101,13 @@ class XMLParse:
return resp_xml
class PKCS7Encoder():
class PKCS7Encoder:
"""提供基于PKCS7算法的加解密接口"""
block_size = 32
def encode(self, text):
""" 对需要加密的明文进行填充补位
"""对需要加密的明文进行填充补位
@param text: 需要进行填充补位操作的明文
@return: 补齐明文字符串
"""
@@ -134,7 +135,6 @@ class Prpcrypt(object):
"""提供接收和推送给企业微信消息的加解密接口"""
def __init__(self, key):
# self.key = base64.b64decode(key+"=")
self.key = key
# 设置加解密模式为AES的CBC模式
@@ -147,7 +147,12 @@ class Prpcrypt(object):
"""
# 16位随机字符串添加到明文开头
text = text.encode()
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
text = (
self.get_random_str()
+ struct.pack('I', socket.htonl(len(text)))
+ text
+ receiveid.encode()
)
# 使用自定义的填充方式对明文进行补位填充
pkcs7 = PKCS7Encoder()
@@ -183,9 +188,9 @@ class Prpcrypt(object):
# plain_text = pkcs7.encode(plain_text)
# 去除16位随机字符串
content = plain_text[16:-pad]
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
xml_content = content[4: xml_len + 4]
from_receiveid = content[xml_len + 4:]
xml_len = socket.ntohl(struct.unpack('I', content[:4])[0])
xml_content = content[4 : xml_len + 4]
from_receiveid = content[xml_len + 4 :]
except Exception as e:
logger = logging.getLogger()
logger.error(e)
@@ -196,7 +201,7 @@ class Prpcrypt(object):
return 0, xml_content
def get_random_str(self):
""" 随机生成16位字符串
"""随机生成16位字符串
@return: 16位字符串
"""
return str(random.randint(1000000000000000, 9999999999999999)).encode()
@@ -206,10 +211,10 @@ class WXBizMsgCrypt(object):
# 构造函数
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
try:
self.key = base64.b64decode(sEncodingAESKey + "=")
self.key = base64.b64decode(sEncodingAESKey + '=')
assert len(self.key) == 32
except:
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
except Exception:
throw_exception('[error]: EncodingAESKey unvalid !', FormatException)
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
self.m_sToken = sToken
self.m_sReceiveId = sReceiveId

View File

@@ -7,15 +7,22 @@ from quart import Quart
import xml.etree.ElementTree as ET
from typing import Callable, Dict, Any
from .wecomevent import WecomEvent
from pkg.platform.types import events as platform_events, message as platform_message
from pkg.platform.types import message as platform_message
import aiofiles
class WecomClient():
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str):
class WecomClient:
def __init__(
self,
corpid: str,
secret: str,
token: str,
EncodingAESKey: str,
contacts_secret: str,
):
self.corpid = corpid
self.secret = secret
self.access_token_for_contacts =''
self.access_token_for_contacts = ''
self.token = token
self.aes = EncodingAESKey
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
@@ -23,19 +30,26 @@ class WecomClient():
self.secret_for_contacts = contacts_secret
self.app = Quart(__name__)
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
"example":[],
'example': [],
}
#access——token操作
# access——token操作
async def check_access_token(self):
return bool(self.access_token and self.access_token.strip())
async def check_access_token_for_contacts(self):
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
return bool(
self.access_token_for_contacts and self.access_token_for_contacts.strip()
)
async def get_access_token(self,secret):
async def get_access_token(self, secret):
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
async with httpx.AsyncClient() as client:
response = await client.get(url)
@@ -43,146 +57,163 @@ class WecomClient():
if 'access_token' in data:
return data['access_token']
else:
raise Exception(f"未获取access token: {data}")
raise Exception(f'未获取access token: {data}')
async def get_users(self):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts
url = (
self.base_url
+ '/user/list_id?access_token='
+ self.access_token_for_contacts
)
async with httpx.AsyncClient() as client:
params = {
"cursor":"",
"limit":10000,
'cursor': '',
'limit': 10000,
}
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
if data['errcode'] == 0:
dept_users = data['dept_user']
userid = []
for user in dept_users:
userid.append(user["userid"])
userid.append(user['userid'])
return userid
else:
raise Exception("未获取用户")
async def send_to_all(self,content:str,agent_id:int):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
raise Exception('未获取用户')
url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts
async def send_to_all(self, content: str, agent_id: int):
if not self.check_access_token_for_contacts():
self.access_token_for_contacts = await self.get_access_token(
self.secret_for_contacts
)
url = (
self.base_url
+ '/message/send?access_token='
+ self.access_token_for_contacts
)
user_ids = await self.get_users()
user_ids_string = "|".join(user_ids)
user_ids_string = '|'.join(user_ids)
async with httpx.AsyncClient() as client:
params = {
"touser" : user_ids_string,
"msgtype" : "text",
"agentid" : agent_id,
"text" : {
"content" : content,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'touser': user_ids_string,
'msgtype': 'text',
'agentid': agent_id,
'text': {
'content': content,
},
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
raise Exception('Failed to send message: ' + str(data))
async def send_image(self,user_id:str,agent_id:int,media_id:str):
async def send_image(self, user_id: str, agent_id: int, media_id: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/media/upload?access_token='+self.access_token
url = self.base_url + '/media/upload?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params = {
"touser" : user_id,
"toparty" : "",
"totag":"",
"agentid" : agent_id,
"msgtype" : "image",
"image" : {
"media_id" : media_id,
'touser': user_id,
'toparty': '',
'totag': '',
'agentid': agent_id,
'msgtype': 'image',
'image': {
'media_id': media_id,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
try:
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
except Exception as e:
raise Exception("Failed to send image: "+str(e))
raise Exception('Failed to send image: ' + str(e))
# 企业微信错误码40014和42001代表accesstoken问题
if data['errcode'] == 40014 or data['errcode'] == 42001:
self.access_token = await self.get_access_token(self.secret)
return await self.send_image(user_id,agent_id,media_id)
return await self.send_image(user_id, agent_id, media_id)
if data['errcode'] != 0:
raise Exception("Failed to send image: "+str(data))
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
raise Exception('Failed to send image: ' + str(data))
async def send_private_msg(self, user_id: str, agent_id: int, content: str):
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url+'/message/send?access_token='+self.access_token
url = self.base_url + '/message/send?access_token=' + self.access_token
async with httpx.AsyncClient() as client:
params={
"touser" : user_id,
"msgtype" : "text",
"agentid" : agent_id,
"text" : {
"content" : content,
params = {
'touser': user_id,
'msgtype': 'text',
'agentid': agent_id,
'text': {
'content': content,
},
"safe":0,
"enable_id_trans": 0,
"enable_duplicate_check": 0,
"duplicate_check_interval": 1800
'safe': 0,
'enable_id_trans': 0,
'enable_duplicate_check': 0,
'duplicate_check_interval': 1800,
}
response = await client.post(url,json=params)
response = await client.post(url, json=params)
data = response.json()
if data['errcode'] == 40014 or data['errcode'] == 42001:
self.access_token = await self.get_access_token(self.secret)
return await self.send_private_msg(user_id,agent_id,content)
return await self.send_private_msg(user_id, agent_id, content)
if data['errcode'] != 0:
raise Exception("Failed to send message: "+str(data))
raise Exception('Failed to send message: ' + str(data))
async def handle_callback_request(self):
"""
处理回调请求,包括 GET 验证和 POST 消息接收。
"""
try:
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
msg_signature = request.args.get("msg_signature")
timestamp = request.args.get("timestamp")
nonce = request.args.get("nonce")
if request.method == "GET":
echostr = request.args.get("echostr")
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if request.method == 'GET':
echostr = request.args.get('echostr')
ret, reply_echo_str = self.wxcpt.VerifyURL(
msg_signature, timestamp, nonce, echostr
)
if ret != 0:
raise Exception(f"验证失败,错误码: {ret}")
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
elif request.method == "POST":
elif request.method == 'POST':
encrypt_msg = await request.data
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
ret, xml_msg = self.wxcpt.DecryptMsg(
encrypt_msg, msg_signature, timestamp, nonce
)
if ret != 0:
raise Exception(f"消息解密失败,错误码: {ret}")
raise Exception(f'消息解密失败,错误码: {ret}')
# 解析消息并处理
message_data = await self.get_message(xml_msg)
if message_data:
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
event = WecomEvent.from_payload(
message_data
) # 转换为 WecomEvent 对象
if event:
await self._handle_message(event)
return "success"
return 'success'
except Exception as e:
return f"Error processing request: {str(e)}", 400
return f'Error processing request: {str(e)}', 400
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
@@ -194,11 +225,13 @@ class WecomClient():
"""
注册消息类型处理器。
"""
def decorator(func: Callable[[WecomEvent], None]):
if msg_type not in self._message_handlers:
self._message_handlers[msg_type] = []
self._message_handlers[msg_type].append(func)
return func
return decorator
async def _handle_message(self, event: WecomEvent):
@@ -216,38 +249,47 @@ class WecomClient():
"""
root = ET.fromstring(xml_msg)
message_data = {
"ToUserName": root.find("ToUserName").text,
"FromUserName": root.find("FromUserName").text,
"CreateTime": int(root.find("CreateTime").text),
"MsgType": root.find("MsgType").text,
"Content": root.find("Content").text if root.find("Content") is not None else None,
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
'ToUserName': root.find('ToUserName').text,
'FromUserName': root.find('FromUserName').text,
'CreateTime': int(root.find('CreateTime').text),
'MsgType': root.find('MsgType').text,
'Content': root.find('Content').text
if root.find('Content') is not None
else None,
'MsgId': int(root.find('MsgId').text)
if root.find('MsgId') is not None
else None,
'AgentID': int(root.find('AgentID').text)
if root.find('AgentID') is not None
else None,
}
if message_data["MsgType"] == "image":
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None
if message_data['MsgType'] == 'image':
message_data['MediaId'] = (
root.find('MediaId').text if root.find('MediaId') is not None else None
)
message_data['PicUrl'] = (
root.find('PicUrl').text if root.find('PicUrl') is not None else None
)
return message_data
@staticmethod
async def get_image_type(image_bytes: bytes) -> str:
"""
通过图片的magic numbers判断图片类型
"""
magic_numbers = {
b'\xFF\xD8\xFF': 'jpg',
b'\x89\x50\x4E\x47': 'png',
b'\xff\xd8\xff': 'jpg',
b'\x89\x50\x4e\x47': 'png',
b'\x47\x49\x46': 'gif',
b'\x42\x4D': 'bmp',
b'\x00\x00\x01\x00': 'ico'
b'\x42\x4d': 'bmp',
b'\x00\x00\x01\x00': 'ico',
}
for magic, ext in magic_numbers.items():
if image_bytes.startswith(magic):
return ext
return 'jpg' # 默认返回jpg
async def upload_to_work(self, image: platform_message.Image):
"""
@@ -256,9 +298,14 @@ class WecomClient():
if not await self.check_access_token():
self.access_token = await self.get_access_token(self.secret)
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
url = (
self.base_url
+ '/media/upload?access_token='
+ self.access_token
+ '&type=file'
)
file_bytes = None
file_name = "uploaded_file.txt"
file_name = 'uploaded_file.txt'
# 获取文件的二进制数据
if image.path:
@@ -277,20 +324,22 @@ class WecomClient():
padded_base64 = base64_data + '=' * padding
file_bytes = base64.b64decode(padded_base64)
except binascii.Error as e:
raise ValueError(f"Invalid base64 string: {str(e)}")
raise ValueError(f'Invalid base64 string: {str(e)}')
else:
raise ValueError("image对象出错")
raise ValueError('image对象出错')
# 设置 multipart/form-data 格式的文件
boundary = "-------------------------acebdf13572468"
headers = {
'Content-Type': f'multipart/form-data; boundary={boundary}'
}
boundary = '-------------------------acebdf13572468'
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
body = (
f"--{boundary}\r\n"
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
f"Content-Type: application/octet-stream\r\n\r\n"
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
(
f'--{boundary}\r\n'
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
f'Content-Type: application/octet-stream\r\n\r\n'
).encode('utf-8')
+ file_bytes
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
)
# 上传文件
async with httpx.AsyncClient() as client:
@@ -300,19 +349,18 @@ class WecomClient():
self.access_token = await self.get_access_token(self.secret)
media_id = await self.upload_to_work(image)
if data.get('errcode', 0) != 0:
raise Exception("failed to upload file")
raise Exception('failed to upload file')
media_id = data.get('media_id')
return media_id
async def download_image_to_bytes(self,url:str) -> bytes:
async def download_image_to_bytes(self, url: str) -> bytes:
async with httpx.AsyncClient() as client:
response = await client.get(url)
response.raise_for_status()
return response.content
#进行media_id的获取
# 进行media_id的获取
async def get_media_id(self, image: platform_message.Image):
media_id = await self.upload_to_work(image=image)
return media_id

View File

@@ -4,7 +4,7 @@
# Author: jonyqin
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
# File Name: ierror.py
# Description:定义错误码含义
# Description:定义错误码含义
#########################################################################
WXBizMsgCrypt_OK = 0
WXBizMsgCrypt_ValidateSignature_Error = -40001
@@ -17,4 +17,4 @@ WXBizMsgCrypt_DecryptAES_Error = -40007
WXBizMsgCrypt_IllegalBuffer = -40008
WXBizMsgCrypt_EncodeBase64_Error = -40009
WXBizMsgCrypt_DecodeBase64_Error = -40010
WXBizMsgCrypt_GenReturnXml_Error = -40011
WXBizMsgCrypt_GenReturnXml_Error = -40011

View File

@@ -9,7 +9,7 @@ class WecomEvent(dict):
"""
@staticmethod
def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]:
def from_payload(payload: Dict[str, Any]) -> Optional['WecomEvent']:
"""
从企业微信事件数据构造 `WecomEvent` 对象。
@@ -34,14 +34,14 @@ class WecomEvent(dict):
Returns:
str: 事件类型。
"""
return self.get("MsgType", "")
return self.get('MsgType', '')
@property
def picurl(self) -> str:
"""
图片链接
"""
return self.get("PicUrl")
return self.get('PicUrl')
@property
def detail_type(self) -> str:
@@ -53,8 +53,8 @@ class WecomEvent(dict):
Returns:
str: 事件详细类型。
"""
if self.type == "event":
return self.get("Event", "")
if self.type == 'event':
return self.get('Event', '')
return self.type
@property
@@ -65,7 +65,7 @@ class WecomEvent(dict):
Returns:
str: 事件名。
"""
return f"{self.type}.{self.detail_type}"
return f'{self.type}.{self.detail_type}'
@property
def user_id(self) -> Optional[str]:
@@ -75,8 +75,8 @@ class WecomEvent(dict):
Returns:
Optional[str]: 用户 ID。
"""
return self.get("FromUserName")
return self.get('FromUserName')
@property
def agent_id(self) -> Optional[int]:
"""
@@ -85,7 +85,7 @@ class WecomEvent(dict):
Returns:
Optional[int]: 机器人 ID。
"""
return self.get("AgentID")
return self.get('AgentID')
@property
def receiver_id(self) -> Optional[str]:
@@ -95,7 +95,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 接收者 ID。
"""
return self.get("ToUserName")
return self.get('ToUserName')
@property
def message_id(self) -> Optional[str]:
@@ -105,7 +105,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 消息 ID。
"""
return self.get("MsgId")
return self.get('MsgId')
@property
def message(self) -> Optional[str]:
@@ -115,7 +115,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 消息内容。
"""
return self.get("Content")
return self.get('Content')
@property
def media_id(self) -> Optional[str]:
@@ -125,7 +125,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 媒体文件 ID。
"""
return self.get("MediaId")
return self.get('MediaId')
@property
def timestamp(self) -> Optional[int]:
@@ -135,7 +135,7 @@ class WecomEvent(dict):
Returns:
Optional[int]: 时间戳。
"""
return self.get("CreateTime")
return self.get('CreateTime')
@property
def event_key(self) -> Optional[str]:
@@ -145,7 +145,7 @@ class WecomEvent(dict):
Returns:
Optional[str]: 事件 Key。
"""
return self.get("EventKey")
return self.get('EventKey')
def __getattr__(self, key: str) -> Optional[Any]:
"""
@@ -176,4 +176,4 @@ class WecomEvent(dict):
Returns:
str: 字符串表示。
"""
return f"<WecomEvent {super().__repr__()}>"
return f'<WecomEvent {super().__repr__()}>'