mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
style: introduce ruff as linter and formatter (#1356)
* style: remove necessary imports * style: fix F841 * style: fix F401 * style: fix F811 * style: fix E402 * style: fix E721 * style: fix E722 * style: fix E722 * style: fix F541 * style: ruff format * style: all passed * style: add ruff in deps * style: more ignores in ruff.toml * style: add pre-commit
This commit is contained in:
committed by
GitHub
parent
09e70d70e9
commit
209f16af76
@@ -1,2 +1,4 @@
|
||||
from .v1 import client
|
||||
from .v1 import errors
|
||||
from .v1 import client as client
|
||||
from .v1 import errors as errors
|
||||
|
||||
__all__ = ['client', 'errors']
|
||||
|
||||
@@ -8,25 +8,33 @@ import json
|
||||
|
||||
class TestDifyClient:
|
||||
async def test_chat_messages(self):
|
||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
|
||||
cln = client.AsyncDifyServiceClient(
|
||||
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||
)
|
||||
|
||||
async for chunk in cln.chat_messages(inputs={}, query="调用工具查看现在几点?", user="test"):
|
||||
async for chunk in cln.chat_messages(
|
||||
inputs={}, query='调用工具查看现在几点?', user='test'
|
||||
):
|
||||
print(json.dumps(chunk, ensure_ascii=False, indent=4))
|
||||
|
||||
async def test_upload_file(self):
|
||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
|
||||
cln = client.AsyncDifyServiceClient(
|
||||
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||
)
|
||||
|
||||
file_bytes = open("img.png", "rb").read()
|
||||
file_bytes = open('img.png', 'rb').read()
|
||||
|
||||
print(type(file_bytes))
|
||||
|
||||
file = ("img2.png", file_bytes, "image/png")
|
||||
file = ('img2.png', file_bytes, 'image/png')
|
||||
|
||||
resp = await cln.upload_file(file=file, user="test")
|
||||
resp = await cln.upload_file(file=file, user='test')
|
||||
print(json.dumps(resp, ensure_ascii=False, indent=4))
|
||||
|
||||
async def test_workflow_run(self):
|
||||
cln = client.AsyncDifyServiceClient(api_key=os.getenv("DIFY_API_KEY"), base_url=os.getenv("DIFY_BASE_URL"))
|
||||
cln = client.AsyncDifyServiceClient(
|
||||
api_key=os.getenv('DIFY_API_KEY'), base_url=os.getenv('DIFY_BASE_URL')
|
||||
)
|
||||
|
||||
# resp = await cln.workflow_run(inputs={}, user="test")
|
||||
# # print(json.dumps(resp, ensure_ascii=False, indent=4))
|
||||
@@ -34,11 +42,12 @@ class TestDifyClient:
|
||||
chunks = []
|
||||
|
||||
ignored_events = ['text_chunk']
|
||||
async for chunk in cln.workflow_run(inputs={}, user="test"):
|
||||
async for chunk in cln.workflow_run(inputs={}, user='test'):
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
chunks.append(chunk)
|
||||
print(json.dumps(chunks, ensure_ascii=False, indent=4))
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(TestDifyClient().test_chat_messages())
|
||||
|
||||
@@ -12,11 +12,11 @@ class AsyncDifyServiceClient:
|
||||
|
||||
api_key: str
|
||||
base_url: str
|
||||
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
api_key: str,
|
||||
base_url: str = "https://api.dify.ai/v1",
|
||||
base_url: str = 'https://api.dify.ai/v1',
|
||||
) -> None:
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
@@ -26,76 +26,81 @@ class AsyncDifyServiceClient:
|
||||
inputs: dict[str, typing.Any],
|
||||
query: str,
|
||||
user: str,
|
||||
response_mode: str = "streaming", # 当前不支持 blocking
|
||||
conversation_id: str = "",
|
||||
response_mode: str = 'streaming', # 当前不支持 blocking
|
||||
conversation_id: str = '',
|
||||
files: list[dict[str, typing.Any]] = [],
|
||||
timeout: float = 30.0,
|
||||
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
||||
"""发送消息"""
|
||||
if response_mode != "streaming":
|
||||
raise DifyAPIError("当前仅支持 streaming 模式")
|
||||
|
||||
if response_mode != 'streaming':
|
||||
raise DifyAPIError('当前仅支持 streaming 模式')
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
trust_env=True,
|
||||
timeout=timeout,
|
||||
) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"/chat-messages",
|
||||
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
||||
'POST',
|
||||
'/chat-messages',
|
||||
headers={
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
json={
|
||||
"inputs": inputs,
|
||||
"query": query,
|
||||
"user": user,
|
||||
"response_mode": response_mode,
|
||||
"conversation_id": conversation_id,
|
||||
"files": files,
|
||||
'inputs': inputs,
|
||||
'query': query,
|
||||
'user': user,
|
||||
'response_mode': response_mode,
|
||||
'conversation_id': conversation_id,
|
||||
'files': files,
|
||||
},
|
||||
) as r:
|
||||
async for chunk in r.aiter_lines():
|
||||
if r.status_code != 200:
|
||||
raise DifyAPIError(f"{r.status_code} {chunk}")
|
||||
if chunk.strip() == "":
|
||||
raise DifyAPIError(f'{r.status_code} {chunk}')
|
||||
if chunk.strip() == '':
|
||||
continue
|
||||
if chunk.startswith("data:"):
|
||||
if chunk.startswith('data:'):
|
||||
yield json.loads(chunk[5:])
|
||||
|
||||
|
||||
async def workflow_run(
|
||||
self,
|
||||
inputs: dict[str, typing.Any],
|
||||
user: str,
|
||||
response_mode: str = "streaming", # 当前不支持 blocking
|
||||
response_mode: str = 'streaming', # 当前不支持 blocking
|
||||
files: list[dict[str, typing.Any]] = [],
|
||||
timeout: float = 30.0,
|
||||
) -> typing.AsyncGenerator[dict[str, typing.Any], None]:
|
||||
"""运行工作流"""
|
||||
if response_mode != "streaming":
|
||||
raise DifyAPIError("当前仅支持 streaming 模式")
|
||||
|
||||
if response_mode != 'streaming':
|
||||
raise DifyAPIError('当前仅支持 streaming 模式')
|
||||
|
||||
async with httpx.AsyncClient(
|
||||
base_url=self.base_url,
|
||||
trust_env=True,
|
||||
timeout=timeout,
|
||||
) as client:
|
||||
|
||||
async with client.stream(
|
||||
"POST",
|
||||
"/workflows/run",
|
||||
headers={"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"},
|
||||
'POST',
|
||||
'/workflows/run',
|
||||
headers={
|
||||
'Authorization': f'Bearer {self.api_key}',
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
json={
|
||||
"inputs": inputs,
|
||||
"user": user,
|
||||
"response_mode": response_mode,
|
||||
"files": files,
|
||||
'inputs': inputs,
|
||||
'user': user,
|
||||
'response_mode': response_mode,
|
||||
'files': files,
|
||||
},
|
||||
) as r:
|
||||
async for chunk in r.aiter_lines():
|
||||
if r.status_code != 200:
|
||||
raise DifyAPIError(f"{r.status_code} {chunk}")
|
||||
if chunk.strip() == "":
|
||||
raise DifyAPIError(f'{r.status_code} {chunk}')
|
||||
if chunk.strip() == '':
|
||||
continue
|
||||
if chunk.startswith("data:"):
|
||||
if chunk.startswith('data:'):
|
||||
yield json.loads(chunk[5:])
|
||||
|
||||
async def upload_file(
|
||||
@@ -112,15 +117,15 @@ class AsyncDifyServiceClient:
|
||||
) as client:
|
||||
# multipart/form-data
|
||||
response = await client.post(
|
||||
"/files/upload",
|
||||
headers={"Authorization": f"Bearer {self.api_key}"},
|
||||
'/files/upload',
|
||||
headers={'Authorization': f'Bearer {self.api_key}'},
|
||||
files={
|
||||
"file": file,
|
||||
"user": (None, user),
|
||||
'file': file,
|
||||
'user': (None, user),
|
||||
},
|
||||
)
|
||||
|
||||
if response.status_code != 201:
|
||||
raise DifyAPIError(f"{response.status_code} {response.text}")
|
||||
raise DifyAPIError(f'{response.status_code} {response.text}')
|
||||
|
||||
return response.json()
|
||||
|
||||
@@ -7,11 +7,11 @@ import os
|
||||
|
||||
class TestDifyClient:
|
||||
async def test_chat_messages(self):
|
||||
cln = client.DifyClient(api_key=os.getenv("DIFY_API_KEY"))
|
||||
cln = client.DifyClient(api_key=os.getenv('DIFY_API_KEY'))
|
||||
|
||||
resp = await cln.chat_messages(inputs={}, query="Who are you?", user_id="test")
|
||||
resp = await cln.chat_messages(inputs={}, query='Who are you?', user_id='test')
|
||||
print(resp)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
if __name__ == '__main__':
|
||||
asyncio.run(TestDifyClient().test_chat_messages())
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import asyncio
|
||||
import json
|
||||
import dingtalk_stream
|
||||
from dingtalk_stream import AckMessage
|
||||
|
||||
|
||||
class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
||||
def __init__(self, client):
|
||||
self.msg_id = ''
|
||||
@@ -10,6 +10,7 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
||||
self.client = client # 用于更新 DingTalkClient 中的 incoming_message
|
||||
|
||||
"""处理钉钉消息"""
|
||||
|
||||
async def process(self, callback: dingtalk_stream.CallbackMessage):
|
||||
incoming_message = dingtalk_stream.ChatbotMessage.from_dict(callback.data)
|
||||
if incoming_message.message_id != self.msg_id:
|
||||
@@ -26,6 +27,8 @@ class EchoTextHandler(dingtalk_stream.ChatbotHandler):
|
||||
|
||||
return self.incoming_message
|
||||
|
||||
|
||||
async def get_dingtalk_client(client_id, client_secret):
|
||||
from api import DingTalkClient # 延迟导入,避免循环导入
|
||||
|
||||
return DingTalkClient(client_id, client_secret)
|
||||
|
||||
@@ -10,7 +10,9 @@ import traceback
|
||||
|
||||
|
||||
class DingTalkClient:
|
||||
def __init__(self, client_id: str, client_secret: str,robot_name:str,robot_code:str):
|
||||
def __init__(
|
||||
self, client_id: str, client_secret: str, robot_name: str, robot_code: str
|
||||
):
|
||||
"""初始化 WebSocket 连接并自动启动"""
|
||||
self.credential = dingtalk_stream.Credential(client_id, client_secret)
|
||||
self.client = dingtalk_stream.DingTalkStreamClient(self.credential)
|
||||
@@ -18,106 +20,91 @@ class DingTalkClient:
|
||||
self.secret = client_secret
|
||||
# 在 DingTalkClient 中传入自己作为参数,避免循环导入
|
||||
self.EchoTextHandler = EchoTextHandler(self)
|
||||
self.client.register_callback_handler(dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler)
|
||||
self.client.register_callback_handler(
|
||||
dingtalk_stream.chatbot.ChatbotMessage.TOPIC, self.EchoTextHandler
|
||||
)
|
||||
self._message_handlers = {
|
||||
"example":[],
|
||||
'example': [],
|
||||
}
|
||||
self.access_token = ''
|
||||
self.robot_name = robot_name
|
||||
self.robot_code = robot_code
|
||||
self.access_token_expiry_time = ''
|
||||
|
||||
|
||||
|
||||
async def get_access_token(self):
|
||||
url = "https://api.dingtalk.com/v1.0/oauth2/accessToken"
|
||||
headers = {
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
data = {
|
||||
"appKey": self.key,
|
||||
"appSecret": self.secret
|
||||
}
|
||||
url = 'https://api.dingtalk.com/v1.0/oauth2/accessToken'
|
||||
headers = {'Content-Type': 'application/json'}
|
||||
data = {'appKey': self.key, 'appSecret': self.secret}
|
||||
async with httpx.AsyncClient() as client:
|
||||
try:
|
||||
response = await client.post(url,json=data,headers=headers)
|
||||
response = await client.post(url, json=data, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
self.access_token = response_data.get("accessToken")
|
||||
expires_in = int(response_data.get("expireIn",7200))
|
||||
self.access_token = response_data.get('accessToken')
|
||||
expires_in = int(response_data.get('expireIn', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
except Exception as e:
|
||||
raise Exception(e)
|
||||
|
||||
|
||||
async def is_token_expired(self):
|
||||
"""检查token是否过期"""
|
||||
if self.access_token_expiry_time is None:
|
||||
return True
|
||||
return time.time() > self.access_token_expiry_time
|
||||
|
||||
|
||||
async def check_access_token(self):
|
||||
if not self.access_token or await self.is_token_expired():
|
||||
return False
|
||||
return bool(self.access_token and self.access_token.strip())
|
||||
|
||||
async def download_image(self,download_code:str):
|
||||
async def download_image(self, download_code: str):
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
||||
params = {
|
||||
"downloadCode":download_code,
|
||||
"robotCode":self.robot_code
|
||||
}
|
||||
headers ={
|
||||
"x-acs-dingtalk-access-token": self.access_token
|
||||
}
|
||||
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
|
||||
headers = {'x-acs-dingtalk-access-token': self.access_token}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, json=params)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
download_url = result.get("downloadUrl")
|
||||
download_url = result.get('downloadUrl')
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code}, {response.text}")
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
if download_url:
|
||||
return await self.download_url_to_base64(download_url)
|
||||
|
||||
async def download_url_to_base64(self,download_url):
|
||||
async def download_url_to_base64(self, download_url):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
|
||||
|
||||
if response.status_code == 200:
|
||||
|
||||
file_bytes = response.content
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8') # 返回字符串格式
|
||||
base64_str = base64.b64encode(file_bytes).decode(
|
||||
'utf-8'
|
||||
) # 返回字符串格式
|
||||
return base64_str
|
||||
else:
|
||||
raise Exception("获取文件失败")
|
||||
|
||||
async def get_audio_url(self,download_code:str):
|
||||
raise Exception('获取文件失败')
|
||||
|
||||
async def get_audio_url(self, download_code: str):
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
url = 'https://api.dingtalk.com/v1.0/robot/messageFiles/download'
|
||||
params = {
|
||||
"downloadCode":download_code,
|
||||
"robotCode":self.robot_code
|
||||
}
|
||||
headers ={
|
||||
"x-acs-dingtalk-access-token": self.access_token
|
||||
}
|
||||
params = {'downloadCode': download_code, 'robotCode': self.robot_code}
|
||||
headers = {'x-acs-dingtalk-access-token': self.access_token}
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url, headers=headers, json=params)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
download_url = result.get("downloadUrl")
|
||||
download_url = result.get('downloadUrl')
|
||||
if download_url:
|
||||
return await self.download_url_to_base64(download_url)
|
||||
else:
|
||||
raise Exception("获取音频失败")
|
||||
raise Exception('获取音频失败')
|
||||
else:
|
||||
raise Exception(f"Error: {response.status_code}, {response.text}")
|
||||
|
||||
raise Exception(f'Error: {response.status_code}, {response.text}')
|
||||
|
||||
async def update_incoming_message(self, message):
|
||||
"""异步更新 DingTalkClient 中的 incoming_message"""
|
||||
message_data = await self.get_message(message)
|
||||
@@ -125,24 +112,21 @@ class DingTalkClient:
|
||||
event = DingTalkEvent.from_payload(message_data)
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
|
||||
|
||||
async def send_message(self,content:str,incoming_message):
|
||||
self.EchoTextHandler.reply_text(content,incoming_message)
|
||||
|
||||
async def send_message(self, content: str, incoming_message):
|
||||
self.EchoTextHandler.reply_text(content, incoming_message)
|
||||
|
||||
async def get_incoming_message(self):
|
||||
"""获取收到的消息"""
|
||||
return await self.EchoTextHandler.get_incoming_message()
|
||||
|
||||
|
||||
|
||||
def on_message(self, msg_type: str):
|
||||
def decorator(func: Callable[[DingTalkEvent], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(self, event: DingTalkEvent):
|
||||
@@ -154,40 +138,44 @@ class DingTalkClient:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
|
||||
async def get_message(self,incoming_message:dingtalk_stream.chatbot.ChatbotMessage):
|
||||
async def get_message(
|
||||
self, incoming_message: dingtalk_stream.chatbot.ChatbotMessage
|
||||
):
|
||||
try:
|
||||
|
||||
# print(json.dumps(incoming_message.to_dict(), indent=4, ensure_ascii=False))
|
||||
message_data = {
|
||||
"IncomingMessage":incoming_message,
|
||||
'IncomingMessage': incoming_message,
|
||||
}
|
||||
if str(incoming_message.conversation_type) == '1':
|
||||
message_data["conversation_type"] = 'FriendMessage'
|
||||
message_data['conversation_type'] = 'FriendMessage'
|
||||
elif str(incoming_message.conversation_type) == '2':
|
||||
message_data["conversation_type"] = 'GroupMessage'
|
||||
message_data['conversation_type'] = 'GroupMessage'
|
||||
|
||||
|
||||
if incoming_message.message_type == 'richText':
|
||||
|
||||
data = incoming_message.rich_text_content.to_dict()
|
||||
for item in data['richText']:
|
||||
if 'text' in item:
|
||||
message_data["Content"] = item['text']
|
||||
message_data['Content'] = item['text']
|
||||
if incoming_message.get_image_list()[0]:
|
||||
message_data["Picture"] = await self.download_image(incoming_message.get_image_list()[0])
|
||||
message_data["Type"] = 'text'
|
||||
|
||||
message_data['Picture'] = await self.download_image(
|
||||
incoming_message.get_image_list()[0]
|
||||
)
|
||||
message_data['Type'] = 'text'
|
||||
|
||||
elif incoming_message.message_type == 'text':
|
||||
message_data['Content'] = incoming_message.get_text_list()[0]
|
||||
|
||||
message_data["Type"] = 'text'
|
||||
message_data['Type'] = 'text'
|
||||
elif incoming_message.message_type == 'picture':
|
||||
message_data['Picture'] = await self.download_image(incoming_message.get_image_list()[0])
|
||||
|
||||
message_data['Picture'] = await self.download_image(
|
||||
incoming_message.get_image_list()[0]
|
||||
)
|
||||
|
||||
message_data['Type'] = 'image'
|
||||
elif incoming_message.message_type == 'audio':
|
||||
message_data['Audio'] = await self.get_audio_url(incoming_message.to_dict()['content']['downloadCode'])
|
||||
message_data['Audio'] = await self.get_audio_url(
|
||||
incoming_message.to_dict()['content']['downloadCode']
|
||||
)
|
||||
|
||||
message_data['Type'] = 'audio'
|
||||
|
||||
@@ -196,56 +184,55 @@ class DingTalkClient:
|
||||
# print("message_data:", json.dumps(copy_message_data, indent=4, ensure_ascii=False))
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
return message_data
|
||||
|
||||
async def send_proactive_message_to_one(self,target_id:str,content:str):
|
||||
async def send_proactive_message_to_one(self, target_id: str, content: str):
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = 'https://api.dingtalk.com/v1.0/robot/oToMessages/batchSend'
|
||||
|
||||
headers ={
|
||||
"x-acs-dingtalk-access-token":self.access_token,
|
||||
"Content-Type":"application/json",
|
||||
headers = {
|
||||
'x-acs-dingtalk-access-token': self.access_token,
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
data ={
|
||||
"robotCode":self.robot_code,
|
||||
"userIds":[target_id],
|
||||
"msgKey": "sampleText",
|
||||
"msgParam": json.dumps({"content":content}),
|
||||
data = {
|
||||
'robotCode': self.robot_code,
|
||||
'userIds': [target_id],
|
||||
'msgKey': 'sampleText',
|
||||
'msgParam': json.dumps({'content': content}),
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url,headers=headers,json=data)
|
||||
await client.post(url, headers=headers, json=data)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def send_proactive_message_to_group(self,target_id:str,content:str):
|
||||
async def send_proactive_message_to_group(self, target_id: str, content: str):
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = 'https://api.dingtalk.com/v1.0/robot/groupMessages/send'
|
||||
|
||||
headers ={
|
||||
"x-acs-dingtalk-access-token":self.access_token,
|
||||
"Content-Type":"application/json",
|
||||
headers = {
|
||||
'x-acs-dingtalk-access-token': self.access_token,
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
|
||||
data ={
|
||||
"robotCode":self.robot_code,
|
||||
"openConversationId":target_id,
|
||||
"msgKey": "sampleText",
|
||||
"msgParam": json.dumps({"content":content}),
|
||||
data = {
|
||||
'robotCode': self.robot_code,
|
||||
'openConversationId': target_id,
|
||||
'msgKey': 'sampleText',
|
||||
'msgParam': json.dumps({'content': content}),
|
||||
}
|
||||
try:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(url,headers=headers,json=data)
|
||||
await client.post(url, headers=headers, json=data)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def start(self):
|
||||
"""启动 WebSocket 连接,监听消息"""
|
||||
await self.client.start()
|
||||
await self.client.start()
|
||||
|
||||
@@ -1,41 +1,39 @@
|
||||
from typing import Dict, Any, Optional
|
||||
import dingtalk_stream
|
||||
|
||||
|
||||
class DingTalkEvent(dict):
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional["DingTalkEvent"]:
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['DingTalkEvent']:
|
||||
try:
|
||||
event = DingTalkEvent(payload)
|
||||
return event
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def content(self):
|
||||
return self.get("Content","")
|
||||
|
||||
@property
|
||||
def incoming_message(self) -> Optional["dingtalk_stream.chatbot.ChatbotMessage"]:
|
||||
return self.get("IncomingMessage")
|
||||
def content(self):
|
||||
return self.get('Content', '')
|
||||
|
||||
@property
|
||||
def incoming_message(self) -> Optional['dingtalk_stream.chatbot.ChatbotMessage']:
|
||||
return self.get('IncomingMessage')
|
||||
|
||||
@property
|
||||
def type(self):
|
||||
return self.get("Type","")
|
||||
|
||||
return self.get('Type', '')
|
||||
|
||||
@property
|
||||
def picture(self):
|
||||
return self.get("Picture","")
|
||||
|
||||
return self.get('Picture', '')
|
||||
|
||||
@property
|
||||
def audio(self):
|
||||
return self.get("Audio","")
|
||||
return self.get('Audio', '')
|
||||
|
||||
@property
|
||||
def conversation(self):
|
||||
return self.get("conversation_type","")
|
||||
|
||||
|
||||
return self.get('conversation_type', '')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
@@ -66,4 +64,4 @@ class DingTalkEvent(dict):
|
||||
Returns:
|
||||
str: 字符串表示。
|
||||
"""
|
||||
return f"<DingTalkEvent {super().__repr__()}>"
|
||||
return f'<DingTalkEvent {super().__repr__()}>'
|
||||
|
||||
@@ -1,20 +1,14 @@
|
||||
# 微信公众号的加解密算法与企业微信一样,所以直接使用企业微信的加解密算法文件
|
||||
from collections import deque
|
||||
import time
|
||||
import traceback
|
||||
from ..wecom_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
import xml.etree.ElementTree as ET
|
||||
from quart import Quart,request
|
||||
from quart import Quart, request
|
||||
import hashlib
|
||||
from typing import Callable, Dict, Any
|
||||
from typing import Callable
|
||||
from .oaevent import OAEvent
|
||||
import httpx
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
import xml.etree.ElementTree as ET
|
||||
from pkg.platform.sources import officialaccount as oa
|
||||
|
||||
|
||||
|
||||
xml_template = """
|
||||
@@ -28,9 +22,8 @@ xml_template = """
|
||||
"""
|
||||
|
||||
|
||||
class OAClient():
|
||||
|
||||
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str):
|
||||
class OAClient:
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
@@ -38,121 +31,130 @@ class OAClient():
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.access_token = ''
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self._message_handlers = {
|
||||
"example":[],
|
||||
'example': [],
|
||||
}
|
||||
self.access_token_expiry_time = None
|
||||
self.msg_id_map = {}
|
||||
self.generated_content = {}
|
||||
|
||||
async def handle_callback_request(self):
|
||||
|
||||
try:
|
||||
# 每隔100毫秒查询是否生成ai回答
|
||||
start_time = time.time()
|
||||
signature = request.args.get("signature", "")
|
||||
timestamp = request.args.get("timestamp", "")
|
||||
nonce = request.args.get("nonce", "")
|
||||
echostr = request.args.get("echostr", "")
|
||||
msg_signature = request.args.get("msg_signature","")
|
||||
signature = request.args.get('signature', '')
|
||||
timestamp = request.args.get('timestamp', '')
|
||||
nonce = request.args.get('nonce', '')
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
if msg_signature is None:
|
||||
raise Exception("msg_signature不在请求体中")
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if request.method == 'GET':
|
||||
# 校验签名
|
||||
check_str = "".join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
|
||||
|
||||
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||
|
||||
if check_signature == signature:
|
||||
return echostr # 验证成功返回echostr
|
||||
else:
|
||||
raise Exception("拒绝请求")
|
||||
elif request.method == "POST":
|
||||
raise Exception('拒绝请求')
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
wxcpt = WXBizMsgCrypt(self.token,self.aes,self.appid)
|
||||
ret,xml_msg = wxcpt.DecryptMsg(encryt_msg,msg_signature,timestamp,nonce)
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(
|
||||
encryt_msg, msg_signature, timestamp, nonce
|
||||
)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
if ret != 0:
|
||||
raise Exception("消息解密失败")
|
||||
raise Exception('消息解密失败')
|
||||
|
||||
message_data = await self.get_message(xml_msg)
|
||||
if message_data :
|
||||
if message_data:
|
||||
event = OAEvent.from_payload(message_data)
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
|
||||
root = ET.fromstring(xml_msg)
|
||||
from_user = root.find("FromUserName").text # 发送者
|
||||
to_user = root.find("ToUserName").text # 机器人
|
||||
|
||||
from_user = root.find('FromUserName').text # 发送者
|
||||
to_user = root.find('ToUserName').text # 机器人
|
||||
|
||||
timeout = 4.80
|
||||
interval = 0.1
|
||||
while True:
|
||||
content = self.generated_content.pop(message_data["MsgId"], None)
|
||||
content = self.generated_content.pop(message_data['MsgId'], None)
|
||||
if content:
|
||||
response_xml = xml_template.format(
|
||||
to_user=from_user,
|
||||
from_user=to_user,
|
||||
create_time=int(time.time()),
|
||||
content = content
|
||||
content=content,
|
||||
)
|
||||
|
||||
return response_xml
|
||||
|
||||
|
||||
if time.time() - start_time >= timeout:
|
||||
break
|
||||
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
if self.msg_id_map.get(message_data["MsgId"], 1) == 3:
|
||||
|
||||
if self.msg_id_map.get(message_data['MsgId'], 1) == 3:
|
||||
# response_xml = xml_template.format(
|
||||
# to_user=from_user,
|
||||
# from_user=to_user,
|
||||
# create_time=int(time.time()),
|
||||
# content = "请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。"
|
||||
# )
|
||||
print("请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。")
|
||||
print(
|
||||
'请求失效:暂不支持公众号超过15秒的请求,如有需求,请联系 LangBot 团队。'
|
||||
)
|
||||
return ''
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
async def get_message(self, xml_msg: str):
|
||||
|
||||
root = ET.fromstring(xml_msg)
|
||||
|
||||
message_data = {
|
||||
"ToUserName": root.find("ToUserName").text,
|
||||
"FromUserName": root.find("FromUserName").text,
|
||||
"CreateTime": int(root.find("CreateTime").text),
|
||||
"MsgType": root.find("MsgType").text,
|
||||
"Content": root.find("Content").text if root.find("Content") is not None else None,
|
||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
||||
'ToUserName': root.find('ToUserName').text,
|
||||
'FromUserName': root.find('FromUserName').text,
|
||||
'CreateTime': int(root.find('CreateTime').text),
|
||||
'MsgType': root.find('MsgType').text,
|
||||
'Content': root.find('Content').text
|
||||
if root.find('Content') is not None
|
||||
else None,
|
||||
'MsgId': int(root.find('MsgId').text)
|
||||
if root.find('MsgId') is not None
|
||||
else None,
|
||||
}
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
启动 Quart 应用。
|
||||
"""
|
||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||
|
||||
|
||||
def on_message(self, msg_type: str):
|
||||
"""
|
||||
注册消息类型处理器。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[OAEvent], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(self, event: OAEvent):
|
||||
@@ -170,14 +172,19 @@ class OAClient():
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
async def set_message(self,msg_id:int,content:str):
|
||||
async def set_message(self, msg_id: int, content: str):
|
||||
self.generated_content[msg_id] = content
|
||||
|
||||
|
||||
|
||||
class OAClientForLongerResponse():
|
||||
|
||||
def __init__(self,token:str,EncodingAESKey:str,AppID:str,Appsecret:str,LoadingMessage:str):
|
||||
class OAClientForLongerResponse:
|
||||
def __init__(
|
||||
self,
|
||||
token: str,
|
||||
EncodingAESKey: str,
|
||||
AppID: str,
|
||||
Appsecret: str,
|
||||
LoadingMessage: str,
|
||||
):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
@@ -185,9 +192,14 @@ class OAClientForLongerResponse():
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.access_token = ''
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self._message_handlers = {
|
||||
"example":[],
|
||||
'example': [],
|
||||
}
|
||||
self.access_token_expiry_time = None
|
||||
self.loading_message = LoadingMessage
|
||||
@@ -196,50 +208,55 @@ class OAClientForLongerResponse():
|
||||
|
||||
async def handle_callback_request(self):
|
||||
try:
|
||||
start_time = time.time()
|
||||
signature = request.args.get("signature", "")
|
||||
timestamp = request.args.get("timestamp", "")
|
||||
nonce = request.args.get("nonce", "")
|
||||
echostr = request.args.get("echostr", "")
|
||||
msg_signature = request.args.get("msg_signature", "")
|
||||
signature = request.args.get('signature', '')
|
||||
timestamp = request.args.get('timestamp', '')
|
||||
nonce = request.args.get('nonce', '')
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
|
||||
if msg_signature is None:
|
||||
raise Exception("msg_signature不在请求体中")
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if request.method == 'GET':
|
||||
check_str = "".join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode("utf-8")).hexdigest()
|
||||
return echostr if check_signature == signature else "拒绝请求"
|
||||
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||
return echostr if check_signature == signature else '拒绝请求'
|
||||
|
||||
elif request.method == "POST":
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(
|
||||
encryt_msg, msg_signature, timestamp, nonce
|
||||
)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
if ret != 0:
|
||||
raise Exception("消息解密失败")
|
||||
raise Exception('消息解密失败')
|
||||
|
||||
# 解析 XML
|
||||
root = ET.fromstring(xml_msg)
|
||||
from_user = root.find("FromUserName").text
|
||||
to_user = root.find("ToUserName").text
|
||||
|
||||
|
||||
if self.msg_queue.get(from_user) and self.msg_queue[from_user][0]["content"]:
|
||||
from_user = root.find('FromUserName').text
|
||||
to_user = root.find('ToUserName').text
|
||||
|
||||
if (
|
||||
self.msg_queue.get(from_user)
|
||||
and self.msg_queue[from_user][0]['content']
|
||||
):
|
||||
queue_top = self.msg_queue[from_user].pop(0)
|
||||
queue_content = queue_top["content"]
|
||||
queue_content = queue_top['content']
|
||||
|
||||
# 弹出用户消息
|
||||
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user]:
|
||||
if (
|
||||
self.user_msg_queue.get(from_user)
|
||||
and self.user_msg_queue[from_user]
|
||||
):
|
||||
self.user_msg_queue[from_user].pop(0)
|
||||
|
||||
response_xml = xml_template.format(
|
||||
to_user=from_user,
|
||||
from_user=to_user,
|
||||
create_time=int(time.time()),
|
||||
content=queue_content
|
||||
content=queue_content,
|
||||
)
|
||||
return response_xml
|
||||
|
||||
@@ -248,65 +265,67 @@ class OAClientForLongerResponse():
|
||||
to_user=from_user,
|
||||
from_user=to_user,
|
||||
create_time=int(time.time()),
|
||||
content=self.loading_message
|
||||
content=self.loading_message,
|
||||
)
|
||||
|
||||
if self.user_msg_queue.get(from_user) and self.user_msg_queue[from_user][0]["content"]:
|
||||
|
||||
if (
|
||||
self.user_msg_queue.get(from_user)
|
||||
and self.user_msg_queue[from_user][0]['content']
|
||||
):
|
||||
return response_xml
|
||||
else:
|
||||
message_data = await self.get_message(xml_msg)
|
||||
|
||||
|
||||
if message_data:
|
||||
event = OAEvent.from_payload(message_data)
|
||||
if event:
|
||||
self.user_msg_queue.setdefault(from_user,[]).append(
|
||||
self.user_msg_queue.setdefault(from_user, []).append(
|
||||
{
|
||||
"content":event.message,
|
||||
'content': event.message,
|
||||
}
|
||||
)
|
||||
await self._handle_message(event)
|
||||
|
||||
return response_xml
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
|
||||
|
||||
async def get_message(self, xml_msg: str):
|
||||
|
||||
root = ET.fromstring(xml_msg)
|
||||
|
||||
message_data = {
|
||||
"ToUserName": root.find("ToUserName").text,
|
||||
"FromUserName": root.find("FromUserName").text,
|
||||
"CreateTime": int(root.find("CreateTime").text),
|
||||
"MsgType": root.find("MsgType").text,
|
||||
"Content": root.find("Content").text if root.find("Content") is not None else None,
|
||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
||||
'ToUserName': root.find('ToUserName').text,
|
||||
'FromUserName': root.find('FromUserName').text,
|
||||
'CreateTime': int(root.find('CreateTime').text),
|
||||
'MsgType': root.find('MsgType').text,
|
||||
'Content': root.find('Content').text
|
||||
if root.find('Content') is not None
|
||||
else None,
|
||||
'MsgId': int(root.find('MsgId').text)
|
||||
if root.find('MsgId') is not None
|
||||
else None,
|
||||
}
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
启动 Quart 应用。
|
||||
"""
|
||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
def on_message(self, msg_type: str):
|
||||
"""
|
||||
注册消息类型处理器。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[OAEvent], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(self, event: OAEvent):
|
||||
@@ -319,22 +338,13 @@ class OAClientForLongerResponse():
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
async def set_message(self,from_user:int,message_id:int,content:str):
|
||||
if from_user not in self.msg_queue:
|
||||
async def set_message(self, from_user: int, message_id: int, content: str):
|
||||
if from_user not in self.msg_queue:
|
||||
self.msg_queue[from_user] = []
|
||||
|
||||
|
||||
self.msg_queue[from_user].append(
|
||||
{
|
||||
"msg_id":message_id,
|
||||
"content":content,
|
||||
'msg_id': message_id,
|
||||
'content': content,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ class OAEvent(dict):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional["OAEvent"]:
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['OAEvent']:
|
||||
"""
|
||||
从微信公众号事件数据构造 `WecomEvent` 对象。
|
||||
|
||||
@@ -34,14 +34,14 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
str: 事件类型。
|
||||
"""
|
||||
return self.get("MsgType", "")
|
||||
|
||||
return self.get('MsgType', '')
|
||||
|
||||
@property
|
||||
def picurl(self) -> str:
|
||||
"""
|
||||
图片链接
|
||||
"""
|
||||
return self.get("PicUrl","")
|
||||
return self.get('PicUrl', '')
|
||||
|
||||
@property
|
||||
def detail_type(self) -> str:
|
||||
@@ -53,8 +53,8 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
str: 事件详细类型。
|
||||
"""
|
||||
if self.type == "event":
|
||||
return self.get("Event", "")
|
||||
if self.type == 'event':
|
||||
return self.get('Event', '')
|
||||
return self.type
|
||||
|
||||
@property
|
||||
@@ -65,15 +65,14 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
str: 事件名。
|
||||
"""
|
||||
return f"{self.type}.{self.detail_type}"
|
||||
return f'{self.type}.{self.detail_type}'
|
||||
|
||||
@property
|
||||
def user_id(self) -> Optional[str]:
|
||||
"""
|
||||
发送方账号
|
||||
"""
|
||||
return self.get("FromUserName")
|
||||
|
||||
return self.get('FromUserName')
|
||||
|
||||
@property
|
||||
def receiver_id(self) -> Optional[str]:
|
||||
@@ -83,7 +82,7 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 接收者 ID。
|
||||
"""
|
||||
return self.get("ToUserName")
|
||||
return self.get('ToUserName')
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
@@ -93,7 +92,7 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 消息 ID。
|
||||
"""
|
||||
return self.get("MsgId")
|
||||
return self.get('MsgId')
|
||||
|
||||
@property
|
||||
def message(self) -> Optional[str]:
|
||||
@@ -103,7 +102,7 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 消息内容。
|
||||
"""
|
||||
return self.get("Content")
|
||||
return self.get('Content')
|
||||
|
||||
@property
|
||||
def media_id(self) -> Optional[str]:
|
||||
@@ -113,7 +112,7 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 媒体文件 ID。
|
||||
"""
|
||||
return self.get("MediaId")
|
||||
return self.get('MediaId')
|
||||
|
||||
@property
|
||||
def timestamp(self) -> Optional[int]:
|
||||
@@ -123,7 +122,7 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
Optional[int]: 时间戳。
|
||||
"""
|
||||
return self.get("CreateTime")
|
||||
return self.get('CreateTime')
|
||||
|
||||
@property
|
||||
def event_key(self) -> Optional[str]:
|
||||
@@ -133,7 +132,7 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 事件 Key。
|
||||
"""
|
||||
return self.get("EventKey")
|
||||
return self.get('EventKey')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
@@ -164,4 +163,4 @@ class OAEvent(dict):
|
||||
Returns:
|
||||
str: 字符串表示。
|
||||
"""
|
||||
return f"<WecomEvent {super().__repr__()}>"
|
||||
return f'<WecomEvent {super().__repr__()}>'
|
||||
|
||||
@@ -1,24 +1,16 @@
|
||||
import time
|
||||
from quart import request
|
||||
import base64
|
||||
import binascii
|
||||
import httpx
|
||||
from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
import aiofiles
|
||||
from pkg.platform.types import events as platform_events
|
||||
from .qqofficialevent import QQOfficialEvent
|
||||
import json
|
||||
import hmac
|
||||
import base64
|
||||
import hashlib
|
||||
import traceback
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
from .qqofficialevent import QQOfficialEvent
|
||||
|
||||
|
||||
def handle_validation(body: dict, bot_secret: str):
|
||||
|
||||
# bot正确的secert是32位的,此处仅为了适配演示demo
|
||||
while len(bot_secret) < 32:
|
||||
bot_secret = bot_secret * 2
|
||||
@@ -36,29 +28,26 @@ def handle_validation(body: dict, bot_secret: str):
|
||||
|
||||
signature_hex = signature.hex()
|
||||
|
||||
response = {
|
||||
"plain_token": body['d']['plain_token'],
|
||||
"signature": signature_hex
|
||||
}
|
||||
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class QQOfficialClient:
|
||||
def __init__(self, secret: str, token: str, app_id: str):
|
||||
self.app = Quart(__name__)
|
||||
self.app.add_url_rule(
|
||||
"/callback/command",
|
||||
"handle_callback",
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=["GET", "POST"],
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self.secret = secret
|
||||
self.token = token
|
||||
self.app_id = app_id
|
||||
self._message_handlers = {
|
||||
}
|
||||
self.base_url = "https://api.sgroup.qq.com"
|
||||
self.access_token = ""
|
||||
self._message_handlers = {}
|
||||
self.base_url = 'https://api.sgroup.qq.com'
|
||||
self.access_token = ''
|
||||
self.access_token_expiry_time = None
|
||||
|
||||
async def check_access_token(self):
|
||||
@@ -66,30 +55,29 @@ class QQOfficialClient:
|
||||
if not self.access_token or await self.is_token_expired():
|
||||
return False
|
||||
return bool(self.access_token and self.access_token.strip())
|
||||
|
||||
|
||||
async def get_access_token(self):
|
||||
"""获取access_token"""
|
||||
url = "https://bots.qq.com/app/getAppAccessToken"
|
||||
url = 'https://bots.qq.com/app/getAppAccessToken'
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"appId":self.app_id,
|
||||
"clientSecret":self.secret,
|
||||
'appId': self.app_id,
|
||||
'clientSecret': self.secret,
|
||||
}
|
||||
headers = {
|
||||
"content-type":"application/json",
|
||||
'content-type': 'application/json',
|
||||
}
|
||||
try:
|
||||
response = await client.post(url,json=params,headers=headers)
|
||||
response = await client.post(url, json=params, headers=headers)
|
||||
if response.status_code == 200:
|
||||
response_data = response.json()
|
||||
access_token = response_data.get("access_token")
|
||||
expires_in = int(response_data.get("expires_in",7200))
|
||||
access_token = response_data.get('access_token')
|
||||
expires_in = int(response_data.get('expires_in', 7200))
|
||||
self.access_token_expiry_time = time.time() + expires_in - 60
|
||||
if access_token:
|
||||
self.access_token = access_token
|
||||
except Exception as e:
|
||||
raise Exception(f"获取access_token失败: {e}")
|
||||
|
||||
raise Exception(f'获取access_token失败: {e}')
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求"""
|
||||
@@ -98,27 +86,24 @@ class QQOfficialClient:
|
||||
body = await request.get_data()
|
||||
payload = json.loads(body)
|
||||
|
||||
|
||||
# 验证是否为回调验证请求
|
||||
if payload.get("op") == 13:
|
||||
if payload.get('op') == 13:
|
||||
# 生成签名
|
||||
response = handle_validation(payload, self.secret)
|
||||
|
||||
return response
|
||||
|
||||
if payload.get("op") == 0:
|
||||
message_data = await self.get_message(payload)
|
||||
if message_data:
|
||||
event = QQOfficialEvent.from_payload(message_data)
|
||||
await self._handle_message(event)
|
||||
|
||||
return {"code": 0, "message": "success"}
|
||||
if payload.get('op') == 0:
|
||||
message_data = await self.get_message(payload)
|
||||
if message_data:
|
||||
event = QQOfficialEvent.from_payload(message_data)
|
||||
await self._handle_message(event)
|
||||
|
||||
return {'code': 0, 'message': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
return {"error": str(e)}, 400
|
||||
|
||||
|
||||
return {'error': str(e)}, 400
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""启动 Quart 应用"""
|
||||
@@ -135,133 +120,140 @@ class QQOfficialClient:
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(self, event:QQOfficialEvent):
|
||||
async def _handle_message(self, event: QQOfficialEvent):
|
||||
"""处理消息事件"""
|
||||
msg_type = event.t
|
||||
if msg_type in self._message_handlers:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
|
||||
|
||||
async def get_message(self,msg:dict) -> Dict[str,Any]:
|
||||
async def get_message(self, msg: dict) -> Dict[str, Any]:
|
||||
"""获取消息"""
|
||||
message_data = {
|
||||
"t": msg.get("t",{}),
|
||||
"user_openid": msg.get("d",{}).get("author",{}).get("user_openid",{}),
|
||||
"timestamp": msg.get("d",{}).get("timestamp",{}),
|
||||
"d_author_id": msg.get("d",{}).get("author",{}).get("id",{}),
|
||||
"content": msg.get("d",{}).get("content",{}),
|
||||
"d_id": msg.get("d",{}).get("id",{}),
|
||||
"id": msg.get("id",{}),
|
||||
"channel_id": msg.get("d",{}).get("channel_id",{}),
|
||||
"username": msg.get("d",{}).get("author",{}).get("username",{}),
|
||||
"guild_id": msg.get("d",{}).get("guild_id",{}),
|
||||
"member_openid": msg.get("d",{}).get("author",{}).get("openid",{}),
|
||||
"group_openid": msg.get("d",{}).get("group_openid",{})
|
||||
't': msg.get('t', {}),
|
||||
'user_openid': msg.get('d', {}).get('author', {}).get('user_openid', {}),
|
||||
'timestamp': msg.get('d', {}).get('timestamp', {}),
|
||||
'd_author_id': msg.get('d', {}).get('author', {}).get('id', {}),
|
||||
'content': msg.get('d', {}).get('content', {}),
|
||||
'd_id': msg.get('d', {}).get('id', {}),
|
||||
'id': msg.get('id', {}),
|
||||
'channel_id': msg.get('d', {}).get('channel_id', {}),
|
||||
'username': msg.get('d', {}).get('author', {}).get('username', {}),
|
||||
'guild_id': msg.get('d', {}).get('guild_id', {}),
|
||||
'member_openid': msg.get('d', {}).get('author', {}).get('openid', {}),
|
||||
'group_openid': msg.get('d', {}).get('group_openid', {}),
|
||||
}
|
||||
attachments = msg.get("d", {}).get("attachments", [])
|
||||
image_attachments = [attachment['url'] for attachment in attachments if await self.is_image(attachment)]
|
||||
image_attachments_type = [attachment['content_type'] for attachment in attachments if await self.is_image(attachment)]
|
||||
attachments = msg.get('d', {}).get('attachments', [])
|
||||
image_attachments = [
|
||||
attachment['url']
|
||||
for attachment in attachments
|
||||
if await self.is_image(attachment)
|
||||
]
|
||||
image_attachments_type = [
|
||||
attachment['content_type']
|
||||
for attachment in attachments
|
||||
if await self.is_image(attachment)
|
||||
]
|
||||
if image_attachments:
|
||||
message_data["image_attachments"] = image_attachments[0]
|
||||
message_data["content_type"] = image_attachments_type[0]
|
||||
message_data['image_attachments'] = image_attachments[0]
|
||||
message_data['content_type'] = image_attachments_type[0]
|
||||
else:
|
||||
|
||||
message_data["image_attachments"] = None
|
||||
|
||||
return message_data
|
||||
|
||||
message_data['image_attachments'] = None
|
||||
|
||||
async def is_image(self,attachment:dict) -> bool:
|
||||
return message_data
|
||||
|
||||
async def is_image(self, attachment: dict) -> bool:
|
||||
"""判断是否为图片附件"""
|
||||
content_type = attachment.get("content_type","")
|
||||
return content_type.startswith("image/")
|
||||
|
||||
|
||||
async def send_private_text_msg(self,user_openid:str,content:str,msg_id:str):
|
||||
content_type = attachment.get('content_type', '')
|
||||
return content_type.startswith('image/')
|
||||
|
||||
async def send_private_text_msg(self, user_openid: str, content: str, msg_id: str):
|
||||
"""发送私聊消息"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
await self.get_access_token()
|
||||
|
||||
url = self.base_url + "/v2/users/" + user_openid + "/messages"
|
||||
url = self.base_url + '/v2/users/' + user_openid + '/messages'
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Authorization": f"QQBot {self.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
data = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
"msg_id": msg_id,
|
||||
'content': content,
|
||||
'msg_type': 0,
|
||||
'msg_id': msg_id,
|
||||
}
|
||||
response = await client.post(url,headers=headers,json=data)
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
raise ValueError(response)
|
||||
|
||||
|
||||
async def send_group_text_msg(self,group_openid:str,content:str,msg_id:str):
|
||||
async def send_group_text_msg(self, group_openid: str, content: str, msg_id: str):
|
||||
"""发送群聊消息"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
|
||||
url = self.base_url + "/v2/groups/" + group_openid + "/messages"
|
||||
url = self.base_url + '/v2/groups/' + group_openid + '/messages'
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Authorization": f"QQBot {self.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
data = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
"msg_id": msg_id,
|
||||
'content': content,
|
||||
'msg_type': 0,
|
||||
'msg_id': msg_id,
|
||||
}
|
||||
response = await client.post(url,headers=headers,json=data)
|
||||
response = await client.post(url, headers=headers, json=data)
|
||||
if response.status_code == 200:
|
||||
return
|
||||
else:
|
||||
raise Exception(response.read().decode())
|
||||
|
||||
async def send_channle_group_text_msg(self,channel_id:str,content:str,msg_id:str):
|
||||
async def send_channle_group_text_msg(
|
||||
self, channel_id: str, content: str, msg_id: str
|
||||
):
|
||||
"""发送频道群聊消息"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
await self.get_access_token()
|
||||
|
||||
url = self.base_url + "/channels/" + channel_id + "/messages"
|
||||
url = self.base_url + '/channels/' + channel_id + '/messages'
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Authorization": f"QQBot {self.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
params = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
"msg_id": msg_id,
|
||||
'content': content,
|
||||
'msg_type': 0,
|
||||
'msg_id': msg_id,
|
||||
}
|
||||
response = await client.post(url,headers=headers,json=params)
|
||||
response = await client.post(url, headers=headers, json=params)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
raise Exception(response)
|
||||
|
||||
async def send_channle_private_text_msg(self,guild_id:str,content:str,msg_id:str):
|
||||
async def send_channle_private_text_msg(
|
||||
self, guild_id: str, content: str, msg_id: str
|
||||
):
|
||||
"""发送频道私聊消息"""
|
||||
if not await self.check_access_token():
|
||||
await self.get_access_token()
|
||||
await self.get_access_token()
|
||||
|
||||
url = self.base_url + "/dms/" + guild_id + "/messages"
|
||||
url = self.base_url + '/dms/' + guild_id + '/messages'
|
||||
async with httpx.AsyncClient() as client:
|
||||
headers = {
|
||||
"Authorization": f"QQBot {self.access_token}",
|
||||
"Content-Type": "application/json",
|
||||
'Authorization': f'QQBot {self.access_token}',
|
||||
'Content-Type': 'application/json',
|
||||
}
|
||||
params = {
|
||||
"content": content,
|
||||
"msg_type": 0,
|
||||
"msg_id": msg_id,
|
||||
'content': content,
|
||||
'msg_type': 0,
|
||||
'msg_id': msg_id,
|
||||
}
|
||||
response = await client.post(url,headers=headers,json=params)
|
||||
response = await client.post(url, headers=headers, json=params)
|
||||
if response.status_code == 200:
|
||||
return True
|
||||
else:
|
||||
|
||||
@@ -1,114 +1,112 @@
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
|
||||
class QQOfficialEvent(dict):
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional["QQOfficialEvent"]:
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['QQOfficialEvent']:
|
||||
try:
|
||||
event = QQOfficialEvent(payload)
|
||||
return event
|
||||
except KeyError:
|
||||
return None
|
||||
|
||||
|
||||
@property
|
||||
def t(self) -> str:
|
||||
"""
|
||||
事件类型
|
||||
"""
|
||||
return self.get("t", "")
|
||||
|
||||
return self.get('t', '')
|
||||
|
||||
@property
|
||||
def user_openid(self) -> str:
|
||||
"""
|
||||
用户openid
|
||||
"""
|
||||
return self.get("user_openid",{})
|
||||
|
||||
return self.get('user_openid', {})
|
||||
|
||||
@property
|
||||
def timestamp(self) -> str:
|
||||
"""
|
||||
时间戳
|
||||
"""
|
||||
return self.get("timestamp",{})
|
||||
|
||||
|
||||
return self.get('timestamp', {})
|
||||
|
||||
@property
|
||||
def d_author_id(self) -> str:
|
||||
"""
|
||||
作者id
|
||||
"""
|
||||
return self.get("id",{})
|
||||
|
||||
return self.get('id', {})
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
"""
|
||||
内容
|
||||
"""
|
||||
return self.get("content",'')
|
||||
|
||||
return self.get('content', '')
|
||||
|
||||
@property
|
||||
def d_id(self) -> str:
|
||||
"""
|
||||
d_id
|
||||
"""
|
||||
return self.get("d_id",{})
|
||||
|
||||
return self.get('d_id', {})
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
"""
|
||||
消息id,msg_id
|
||||
"""
|
||||
return self.get("id",{})
|
||||
|
||||
return self.get('id', {})
|
||||
|
||||
@property
|
||||
def channel_id(self) -> str:
|
||||
"""
|
||||
频道id
|
||||
"""
|
||||
return self.get("channel_id",{})
|
||||
|
||||
return self.get('channel_id', {})
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
"""
|
||||
用户名
|
||||
"""
|
||||
return self.get("username",{})
|
||||
|
||||
return self.get('username', {})
|
||||
|
||||
@property
|
||||
def guild_id(self) -> str:
|
||||
"""
|
||||
频道id
|
||||
"""
|
||||
return self.get("guild_id",{})
|
||||
|
||||
return self.get('guild_id', {})
|
||||
|
||||
@property
|
||||
def member_openid(self) -> str:
|
||||
"""
|
||||
成员openid
|
||||
"""
|
||||
return self.get("openid",{})
|
||||
|
||||
return self.get('openid', {})
|
||||
|
||||
@property
|
||||
def attachments(self) -> str:
|
||||
"""
|
||||
附件url
|
||||
"""
|
||||
url = self.get("image_attachments", "")
|
||||
if url and not url.startswith("https://"):
|
||||
url = "https://" + url
|
||||
url = self.get('image_attachments', '')
|
||||
if url and not url.startswith('https://'):
|
||||
url = 'https://' + url
|
||||
return url
|
||||
|
||||
|
||||
@property
|
||||
def group_openid(self) -> str:
|
||||
"""
|
||||
群组id
|
||||
"""
|
||||
return self.get("group_openid",{})
|
||||
|
||||
return self.get('group_openid', {})
|
||||
|
||||
@property
|
||||
def content_type(self) -> str:
|
||||
"""
|
||||
文件类型
|
||||
"""
|
||||
return self.get("content_type","")
|
||||
|
||||
return self.get('content_type', '')
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
#!/usr/bin/env python
|
||||
# -*- encoding:utf-8 -*-
|
||||
|
||||
""" 对企业微信发送给企业后台的消息加解密示例代码.
|
||||
"""对企业微信发送给企业后台的消息加解密示例代码.
|
||||
@copyright: Copyright (c) 1998-2014 Tencent Inc.
|
||||
|
||||
"""
|
||||
|
||||
# ------------------------------------------------------------------------
|
||||
import logging
|
||||
import base64
|
||||
@@ -49,7 +50,7 @@ class SHA1:
|
||||
sortlist = [token, timestamp, nonce, encrypt]
|
||||
sortlist.sort()
|
||||
sha = hashlib.sha1()
|
||||
sha.update("".join(sortlist).encode())
|
||||
sha.update(''.join(sortlist).encode())
|
||||
return ierror.WXBizMsgCrypt_OK, sha.hexdigest()
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
@@ -75,7 +76,7 @@ class XMLParse:
|
||||
"""
|
||||
try:
|
||||
xml_tree = ET.fromstring(xmltext)
|
||||
encrypt = xml_tree.find("Encrypt")
|
||||
encrypt = xml_tree.find('Encrypt')
|
||||
return ierror.WXBizMsgCrypt_OK, encrypt.text
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
@@ -100,13 +101,13 @@ class XMLParse:
|
||||
return resp_xml
|
||||
|
||||
|
||||
class PKCS7Encoder():
|
||||
class PKCS7Encoder:
|
||||
"""提供基于PKCS7算法的加解密接口"""
|
||||
|
||||
block_size = 32
|
||||
|
||||
def encode(self, text):
|
||||
""" 对需要加密的明文进行填充补位
|
||||
"""对需要加密的明文进行填充补位
|
||||
@param text: 需要进行填充补位操作的明文
|
||||
@return: 补齐明文字符串
|
||||
"""
|
||||
@@ -134,7 +135,6 @@ class Prpcrypt(object):
|
||||
"""提供接收和推送给企业微信消息的加解密接口"""
|
||||
|
||||
def __init__(self, key):
|
||||
|
||||
# self.key = base64.b64decode(key+"=")
|
||||
self.key = key
|
||||
# 设置加解密模式为AES的CBC模式
|
||||
@@ -147,7 +147,12 @@ class Prpcrypt(object):
|
||||
"""
|
||||
# 16位随机字符串添加到明文开头
|
||||
text = text.encode()
|
||||
text = self.get_random_str() + struct.pack("I", socket.htonl(len(text))) + text + receiveid.encode()
|
||||
text = (
|
||||
self.get_random_str()
|
||||
+ struct.pack('I', socket.htonl(len(text)))
|
||||
+ text
|
||||
+ receiveid.encode()
|
||||
)
|
||||
|
||||
# 使用自定义的填充方式对明文进行补位填充
|
||||
pkcs7 = PKCS7Encoder()
|
||||
@@ -183,9 +188,9 @@ class Prpcrypt(object):
|
||||
# plain_text = pkcs7.encode(plain_text)
|
||||
# 去除16位随机字符串
|
||||
content = plain_text[16:-pad]
|
||||
xml_len = socket.ntohl(struct.unpack("I", content[: 4])[0])
|
||||
xml_content = content[4: xml_len + 4]
|
||||
from_receiveid = content[xml_len + 4:]
|
||||
xml_len = socket.ntohl(struct.unpack('I', content[:4])[0])
|
||||
xml_content = content[4 : xml_len + 4]
|
||||
from_receiveid = content[xml_len + 4 :]
|
||||
except Exception as e:
|
||||
logger = logging.getLogger()
|
||||
logger.error(e)
|
||||
@@ -196,7 +201,7 @@ class Prpcrypt(object):
|
||||
return 0, xml_content
|
||||
|
||||
def get_random_str(self):
|
||||
""" 随机生成16位字符串
|
||||
"""随机生成16位字符串
|
||||
@return: 16位字符串
|
||||
"""
|
||||
return str(random.randint(1000000000000000, 9999999999999999)).encode()
|
||||
@@ -206,10 +211,10 @@ class WXBizMsgCrypt(object):
|
||||
# 构造函数
|
||||
def __init__(self, sToken, sEncodingAESKey, sReceiveId):
|
||||
try:
|
||||
self.key = base64.b64decode(sEncodingAESKey + "=")
|
||||
self.key = base64.b64decode(sEncodingAESKey + '=')
|
||||
assert len(self.key) == 32
|
||||
except:
|
||||
throw_exception("[error]: EncodingAESKey unvalid !", FormatException)
|
||||
except Exception:
|
||||
throw_exception('[error]: EncodingAESKey unvalid !', FormatException)
|
||||
# return ierror.WXBizMsgCrypt_IllegalAesKey,None
|
||||
self.m_sToken = sToken
|
||||
self.m_sReceiveId = sReceiveId
|
||||
|
||||
@@ -7,15 +7,22 @@ from quart import Quart
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import Callable, Dict, Any
|
||||
from .wecomevent import WecomEvent
|
||||
from pkg.platform.types import events as platform_events, message as platform_message
|
||||
from pkg.platform.types import message as platform_message
|
||||
import aiofiles
|
||||
|
||||
|
||||
class WecomClient():
|
||||
def __init__(self,corpid:str,secret:str,token:str,EncodingAESKey:str,contacts_secret:str):
|
||||
class WecomClient:
|
||||
def __init__(
|
||||
self,
|
||||
corpid: str,
|
||||
secret: str,
|
||||
token: str,
|
||||
EncodingAESKey: str,
|
||||
contacts_secret: str,
|
||||
):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts =''
|
||||
self.access_token_for_contacts = ''
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
@@ -23,19 +30,26 @@ class WecomClient():
|
||||
self.secret_for_contacts = contacts_secret
|
||||
self.app = Quart(__name__)
|
||||
self.wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
self.app.add_url_rule('/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST'])
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self._message_handlers = {
|
||||
"example":[],
|
||||
'example': [],
|
||||
}
|
||||
|
||||
#access——token操作
|
||||
# access——token操作
|
||||
async def check_access_token(self):
|
||||
return bool(self.access_token and self.access_token.strip())
|
||||
|
||||
async def check_access_token_for_contacts(self):
|
||||
return bool(self.access_token_for_contacts and self.access_token_for_contacts.strip())
|
||||
return bool(
|
||||
self.access_token_for_contacts and self.access_token_for_contacts.strip()
|
||||
)
|
||||
|
||||
async def get_access_token(self,secret):
|
||||
async def get_access_token(self, secret):
|
||||
url = f'https://qyapi.weixin.qq.com/cgi-bin/gettoken?corpid={self.corpid}&corpsecret={secret}'
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
@@ -43,146 +57,163 @@ class WecomClient():
|
||||
if 'access_token' in data:
|
||||
return data['access_token']
|
||||
else:
|
||||
raise Exception(f"未获取access token: {data}")
|
||||
raise Exception(f'未获取access token: {data}')
|
||||
|
||||
async def get_users(self):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||
self.access_token_for_contacts = await self.get_access_token(
|
||||
self.secret_for_contacts
|
||||
)
|
||||
|
||||
url = self.base_url+'/user/list_id?access_token='+self.access_token_for_contacts
|
||||
url = (
|
||||
self.base_url
|
||||
+ '/user/list_id?access_token='
|
||||
+ self.access_token_for_contacts
|
||||
)
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"cursor":"",
|
||||
"limit":10000,
|
||||
'cursor': '',
|
||||
'limit': 10000,
|
||||
}
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] == 0:
|
||||
dept_users = data['dept_user']
|
||||
userid = []
|
||||
for user in dept_users:
|
||||
userid.append(user["userid"])
|
||||
userid.append(user['userid'])
|
||||
return userid
|
||||
else:
|
||||
raise Exception("未获取用户")
|
||||
|
||||
async def send_to_all(self,content:str,agent_id:int):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(self.secret_for_contacts)
|
||||
raise Exception('未获取用户')
|
||||
|
||||
url = self.base_url+'/message/send?access_token='+self.access_token_for_contacts
|
||||
async def send_to_all(self, content: str, agent_id: int):
|
||||
if not self.check_access_token_for_contacts():
|
||||
self.access_token_for_contacts = await self.get_access_token(
|
||||
self.secret_for_contacts
|
||||
)
|
||||
|
||||
url = (
|
||||
self.base_url
|
||||
+ '/message/send?access_token='
|
||||
+ self.access_token_for_contacts
|
||||
)
|
||||
user_ids = await self.get_users()
|
||||
user_ids_string = "|".join(user_ids)
|
||||
user_ids_string = '|'.join(user_ids)
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"touser" : user_ids_string,
|
||||
"msgtype" : "text",
|
||||
"agentid" : agent_id,
|
||||
"text" : {
|
||||
"content" : content,
|
||||
},
|
||||
"safe":0,
|
||||
"enable_id_trans": 0,
|
||||
"enable_duplicate_check": 0,
|
||||
"duplicate_check_interval": 1800
|
||||
'touser': user_ids_string,
|
||||
'msgtype': 'text',
|
||||
'agentid': agent_id,
|
||||
'text': {
|
||||
'content': content,
|
||||
},
|
||||
'safe': 0,
|
||||
'enable_id_trans': 0,
|
||||
'enable_duplicate_check': 0,
|
||||
'duplicate_check_interval': 1800,
|
||||
}
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send message: "+str(data))
|
||||
raise Exception('Failed to send message: ' + str(data))
|
||||
|
||||
async def send_image(self,user_id:str,agent_id:int,media_id:str):
|
||||
async def send_image(self, user_id: str, agent_id: int, media_id: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
url = self.base_url+'/media/upload?access_token='+self.access_token
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params = {
|
||||
"touser" : user_id,
|
||||
"toparty" : "",
|
||||
"totag":"",
|
||||
"agentid" : agent_id,
|
||||
"msgtype" : "image",
|
||||
"image" : {
|
||||
"media_id" : media_id,
|
||||
'touser': user_id,
|
||||
'toparty': '',
|
||||
'totag': '',
|
||||
'agentid': agent_id,
|
||||
'msgtype': 'image',
|
||||
'image': {
|
||||
'media_id': media_id,
|
||||
},
|
||||
"safe":0,
|
||||
"enable_id_trans": 0,
|
||||
"enable_duplicate_check": 0,
|
||||
"duplicate_check_interval": 1800
|
||||
'safe': 0,
|
||||
'enable_id_trans': 0,
|
||||
'enable_duplicate_check': 0,
|
||||
'duplicate_check_interval': 1800,
|
||||
}
|
||||
try:
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
except Exception as e:
|
||||
raise Exception("Failed to send image: "+str(e))
|
||||
raise Exception('Failed to send image: ' + str(e))
|
||||
|
||||
# 企业微信错误码40014和42001,代表accesstoken问题
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_image(user_id,agent_id,media_id)
|
||||
return await self.send_image(user_id, agent_id, media_id)
|
||||
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send image: "+str(data))
|
||||
|
||||
async def send_private_msg(self,user_id:str, agent_id:int,content:str):
|
||||
raise Exception('Failed to send image: ' + str(data))
|
||||
|
||||
async def send_private_msg(self, user_id: str, agent_id: int, content: str):
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = self.base_url+'/message/send?access_token='+self.access_token
|
||||
url = self.base_url + '/message/send?access_token=' + self.access_token
|
||||
async with httpx.AsyncClient() as client:
|
||||
params={
|
||||
"touser" : user_id,
|
||||
"msgtype" : "text",
|
||||
"agentid" : agent_id,
|
||||
"text" : {
|
||||
"content" : content,
|
||||
params = {
|
||||
'touser': user_id,
|
||||
'msgtype': 'text',
|
||||
'agentid': agent_id,
|
||||
'text': {
|
||||
'content': content,
|
||||
},
|
||||
"safe":0,
|
||||
"enable_id_trans": 0,
|
||||
"enable_duplicate_check": 0,
|
||||
"duplicate_check_interval": 1800
|
||||
'safe': 0,
|
||||
'enable_id_trans': 0,
|
||||
'enable_duplicate_check': 0,
|
||||
'duplicate_check_interval': 1800,
|
||||
}
|
||||
response = await client.post(url,json=params)
|
||||
response = await client.post(url, json=params)
|
||||
data = response.json()
|
||||
if data['errcode'] == 40014 or data['errcode'] == 42001:
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
return await self.send_private_msg(user_id,agent_id,content)
|
||||
return await self.send_private_msg(user_id, agent_id, content)
|
||||
if data['errcode'] != 0:
|
||||
raise Exception("Failed to send message: "+str(data))
|
||||
raise Exception('Failed to send message: ' + str(data))
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""
|
||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||
"""
|
||||
try:
|
||||
msg_signature = request.args.get('msg_signature')
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
|
||||
msg_signature = request.args.get("msg_signature")
|
||||
timestamp = request.args.get("timestamp")
|
||||
nonce = request.args.get("nonce")
|
||||
|
||||
if request.method == "GET":
|
||||
echostr = request.args.get("echostr")
|
||||
ret, reply_echo_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
ret, reply_echo_str = self.wxcpt.VerifyURL(
|
||||
msg_signature, timestamp, nonce, echostr
|
||||
)
|
||||
if ret != 0:
|
||||
raise Exception(f"验证失败,错误码: {ret}")
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif request.method == "POST":
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
ret, xml_msg = self.wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
ret, xml_msg = self.wxcpt.DecryptMsg(
|
||||
encrypt_msg, msg_signature, timestamp, nonce
|
||||
)
|
||||
if ret != 0:
|
||||
raise Exception(f"消息解密失败,错误码: {ret}")
|
||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||
|
||||
# 解析消息并处理
|
||||
message_data = await self.get_message(xml_msg)
|
||||
if message_data:
|
||||
event = WecomEvent.from_payload(message_data) # 转换为 WecomEvent 对象
|
||||
event = WecomEvent.from_payload(
|
||||
message_data
|
||||
) # 转换为 WecomEvent 对象
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
|
||||
return "success"
|
||||
return 'success'
|
||||
except Exception as e:
|
||||
return f"Error processing request: {str(e)}", 400
|
||||
return f'Error processing request: {str(e)}', 400
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
@@ -194,11 +225,13 @@ class WecomClient():
|
||||
"""
|
||||
注册消息类型处理器。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[WecomEvent], None]):
|
||||
if msg_type not in self._message_handlers:
|
||||
self._message_handlers[msg_type] = []
|
||||
self._message_handlers[msg_type].append(func)
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
async def _handle_message(self, event: WecomEvent):
|
||||
@@ -216,38 +249,47 @@ class WecomClient():
|
||||
"""
|
||||
root = ET.fromstring(xml_msg)
|
||||
message_data = {
|
||||
"ToUserName": root.find("ToUserName").text,
|
||||
"FromUserName": root.find("FromUserName").text,
|
||||
"CreateTime": int(root.find("CreateTime").text),
|
||||
"MsgType": root.find("MsgType").text,
|
||||
"Content": root.find("Content").text if root.find("Content") is not None else None,
|
||||
"MsgId": int(root.find("MsgId").text) if root.find("MsgId") is not None else None,
|
||||
"AgentID": int(root.find("AgentID").text) if root.find("AgentID") is not None else None,
|
||||
'ToUserName': root.find('ToUserName').text,
|
||||
'FromUserName': root.find('FromUserName').text,
|
||||
'CreateTime': int(root.find('CreateTime').text),
|
||||
'MsgType': root.find('MsgType').text,
|
||||
'Content': root.find('Content').text
|
||||
if root.find('Content') is not None
|
||||
else None,
|
||||
'MsgId': int(root.find('MsgId').text)
|
||||
if root.find('MsgId') is not None
|
||||
else None,
|
||||
'AgentID': int(root.find('AgentID').text)
|
||||
if root.find('AgentID') is not None
|
||||
else None,
|
||||
}
|
||||
if message_data["MsgType"] == "image":
|
||||
message_data["MediaId"] = root.find("MediaId").text if root.find("MediaId") is not None else None
|
||||
message_data["PicUrl"] = root.find("PicUrl").text if root.find("PicUrl") is not None else None
|
||||
|
||||
if message_data['MsgType'] == 'image':
|
||||
message_data['MediaId'] = (
|
||||
root.find('MediaId').text if root.find('MediaId') is not None else None
|
||||
)
|
||||
message_data['PicUrl'] = (
|
||||
root.find('PicUrl').text if root.find('PicUrl') is not None else None
|
||||
)
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
@staticmethod
|
||||
async def get_image_type(image_bytes: bytes) -> str:
|
||||
"""
|
||||
通过图片的magic numbers判断图片类型
|
||||
"""
|
||||
magic_numbers = {
|
||||
b'\xFF\xD8\xFF': 'jpg',
|
||||
b'\x89\x50\x4E\x47': 'png',
|
||||
b'\xff\xd8\xff': 'jpg',
|
||||
b'\x89\x50\x4e\x47': 'png',
|
||||
b'\x47\x49\x46': 'gif',
|
||||
b'\x42\x4D': 'bmp',
|
||||
b'\x00\x00\x01\x00': 'ico'
|
||||
b'\x42\x4d': 'bmp',
|
||||
b'\x00\x00\x01\x00': 'ico',
|
||||
}
|
||||
|
||||
|
||||
for magic, ext in magic_numbers.items():
|
||||
if image_bytes.startswith(magic):
|
||||
return ext
|
||||
return 'jpg' # 默认返回jpg
|
||||
|
||||
|
||||
async def upload_to_work(self, image: platform_message.Image):
|
||||
"""
|
||||
@@ -256,9 +298,14 @@ class WecomClient():
|
||||
if not await self.check_access_token():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
|
||||
url = self.base_url + '/media/upload?access_token=' + self.access_token + '&type=file'
|
||||
url = (
|
||||
self.base_url
|
||||
+ '/media/upload?access_token='
|
||||
+ self.access_token
|
||||
+ '&type=file'
|
||||
)
|
||||
file_bytes = None
|
||||
file_name = "uploaded_file.txt"
|
||||
file_name = 'uploaded_file.txt'
|
||||
|
||||
# 获取文件的二进制数据
|
||||
if image.path:
|
||||
@@ -277,20 +324,22 @@ class WecomClient():
|
||||
padded_base64 = base64_data + '=' * padding
|
||||
file_bytes = base64.b64decode(padded_base64)
|
||||
except binascii.Error as e:
|
||||
raise ValueError(f"Invalid base64 string: {str(e)}")
|
||||
raise ValueError(f'Invalid base64 string: {str(e)}')
|
||||
else:
|
||||
raise ValueError("image对象出错")
|
||||
raise ValueError('image对象出错')
|
||||
|
||||
# 设置 multipart/form-data 格式的文件
|
||||
boundary = "-------------------------acebdf13572468"
|
||||
headers = {
|
||||
'Content-Type': f'multipart/form-data; boundary={boundary}'
|
||||
}
|
||||
boundary = '-------------------------acebdf13572468'
|
||||
headers = {'Content-Type': f'multipart/form-data; boundary={boundary}'}
|
||||
body = (
|
||||
f"--{boundary}\r\n"
|
||||
f"Content-Disposition: form-data; name=\"media\"; filename=\"{file_name}\"; filelength={len(file_bytes)}\r\n"
|
||||
f"Content-Type: application/octet-stream\r\n\r\n"
|
||||
).encode('utf-8') + file_bytes + f"\r\n--{boundary}--\r\n".encode('utf-8')
|
||||
(
|
||||
f'--{boundary}\r\n'
|
||||
f'Content-Disposition: form-data; name="media"; filename="{file_name}"; filelength={len(file_bytes)}\r\n'
|
||||
f'Content-Type: application/octet-stream\r\n\r\n'
|
||||
).encode('utf-8')
|
||||
+ file_bytes
|
||||
+ f'\r\n--{boundary}--\r\n'.encode('utf-8')
|
||||
)
|
||||
|
||||
# 上传文件
|
||||
async with httpx.AsyncClient() as client:
|
||||
@@ -300,19 +349,18 @@ class WecomClient():
|
||||
self.access_token = await self.get_access_token(self.secret)
|
||||
media_id = await self.upload_to_work(image)
|
||||
if data.get('errcode', 0) != 0:
|
||||
raise Exception("failed to upload file")
|
||||
raise Exception('failed to upload file')
|
||||
|
||||
media_id = data.get('media_id')
|
||||
return media_id
|
||||
|
||||
async def download_image_to_bytes(self,url:str) -> bytes:
|
||||
async def download_image_to_bytes(self, url: str) -> bytes:
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(url)
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
#进行media_id的获取
|
||||
# 进行media_id的获取
|
||||
async def get_media_id(self, image: platform_message.Image):
|
||||
|
||||
media_id = await self.upload_to_work(image=image)
|
||||
return media_id
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# Author: jonyqin
|
||||
# Created Time: Thu 11 Sep 2014 01:53:58 PM CST
|
||||
# File Name: ierror.py
|
||||
# Description:定义错误码含义
|
||||
# Description:定义错误码含义
|
||||
#########################################################################
|
||||
WXBizMsgCrypt_OK = 0
|
||||
WXBizMsgCrypt_ValidateSignature_Error = -40001
|
||||
@@ -17,4 +17,4 @@ WXBizMsgCrypt_DecryptAES_Error = -40007
|
||||
WXBizMsgCrypt_IllegalBuffer = -40008
|
||||
WXBizMsgCrypt_EncodeBase64_Error = -40009
|
||||
WXBizMsgCrypt_DecodeBase64_Error = -40010
|
||||
WXBizMsgCrypt_GenReturnXml_Error = -40011
|
||||
WXBizMsgCrypt_GenReturnXml_Error = -40011
|
||||
|
||||
@@ -9,7 +9,7 @@ class WecomEvent(dict):
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional["WecomEvent"]:
|
||||
def from_payload(payload: Dict[str, Any]) -> Optional['WecomEvent']:
|
||||
"""
|
||||
从企业微信事件数据构造 `WecomEvent` 对象。
|
||||
|
||||
@@ -34,14 +34,14 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
str: 事件类型。
|
||||
"""
|
||||
return self.get("MsgType", "")
|
||||
|
||||
return self.get('MsgType', '')
|
||||
|
||||
@property
|
||||
def picurl(self) -> str:
|
||||
"""
|
||||
图片链接
|
||||
"""
|
||||
return self.get("PicUrl")
|
||||
return self.get('PicUrl')
|
||||
|
||||
@property
|
||||
def detail_type(self) -> str:
|
||||
@@ -53,8 +53,8 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
str: 事件详细类型。
|
||||
"""
|
||||
if self.type == "event":
|
||||
return self.get("Event", "")
|
||||
if self.type == 'event':
|
||||
return self.get('Event', '')
|
||||
return self.type
|
||||
|
||||
@property
|
||||
@@ -65,7 +65,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
str: 事件名。
|
||||
"""
|
||||
return f"{self.type}.{self.detail_type}"
|
||||
return f'{self.type}.{self.detail_type}'
|
||||
|
||||
@property
|
||||
def user_id(self) -> Optional[str]:
|
||||
@@ -75,8 +75,8 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 用户 ID。
|
||||
"""
|
||||
return self.get("FromUserName")
|
||||
|
||||
return self.get('FromUserName')
|
||||
|
||||
@property
|
||||
def agent_id(self) -> Optional[int]:
|
||||
"""
|
||||
@@ -85,7 +85,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[int]: 机器人 ID。
|
||||
"""
|
||||
return self.get("AgentID")
|
||||
return self.get('AgentID')
|
||||
|
||||
@property
|
||||
def receiver_id(self) -> Optional[str]:
|
||||
@@ -95,7 +95,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 接收者 ID。
|
||||
"""
|
||||
return self.get("ToUserName")
|
||||
return self.get('ToUserName')
|
||||
|
||||
@property
|
||||
def message_id(self) -> Optional[str]:
|
||||
@@ -105,7 +105,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 消息 ID。
|
||||
"""
|
||||
return self.get("MsgId")
|
||||
return self.get('MsgId')
|
||||
|
||||
@property
|
||||
def message(self) -> Optional[str]:
|
||||
@@ -115,7 +115,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 消息内容。
|
||||
"""
|
||||
return self.get("Content")
|
||||
return self.get('Content')
|
||||
|
||||
@property
|
||||
def media_id(self) -> Optional[str]:
|
||||
@@ -125,7 +125,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 媒体文件 ID。
|
||||
"""
|
||||
return self.get("MediaId")
|
||||
return self.get('MediaId')
|
||||
|
||||
@property
|
||||
def timestamp(self) -> Optional[int]:
|
||||
@@ -135,7 +135,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[int]: 时间戳。
|
||||
"""
|
||||
return self.get("CreateTime")
|
||||
return self.get('CreateTime')
|
||||
|
||||
@property
|
||||
def event_key(self) -> Optional[str]:
|
||||
@@ -145,7 +145,7 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
Optional[str]: 事件 Key。
|
||||
"""
|
||||
return self.get("EventKey")
|
||||
return self.get('EventKey')
|
||||
|
||||
def __getattr__(self, key: str) -> Optional[Any]:
|
||||
"""
|
||||
@@ -176,4 +176,4 @@ class WecomEvent(dict):
|
||||
Returns:
|
||||
str: 字符串表示。
|
||||
"""
|
||||
return f"<WecomEvent {super().__repr__()}>"
|
||||
return f'<WecomEvent {super().__repr__()}>'
|
||||
|
||||
Reference in New Issue
Block a user