From 4a02c531b20ae1e4ec86df8cbd0dc74c2d795aaa Mon Sep 17 00:00:00 2001 From: Alfonsxh Date: Tue, 28 Oct 2025 18:30:55 +0800 Subject: [PATCH] refactor: split WeCom callback handlers --- libs/wecom_ai_bot_api/api.py | 99 +++++++++++++++++++----------------- 1 file changed, 53 insertions(+), 46 deletions(-) diff --git a/libs/wecom_ai_bot_api/api.py b/libs/wecom_ai_bot_api/api.py index 41d379a6..9568eab4 100644 --- a/libs/wecom_ai_bot_api/api.py +++ b/libs/wecom_ai_bot_api/api.py @@ -295,7 +295,7 @@ class WecomBotClient: except Exception: await self.logger.error(traceback.format_exc()) - async def _handle_initial_message(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: + async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: """处理企业微信首次推送的消息,返回 stream_id 并开启流水线。 Args: @@ -324,7 +324,7 @@ class WecomBotClient: payload = self._build_stream_payload(session.stream_id, '', False) return await self._encrypt_and_reply(payload, nonce) - async def _handle_stream_refresh(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: + async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]: """处理企业微信的流式刷新请求,按需返回增量片段。 Args: @@ -375,57 +375,64 @@ class WecomBotClient: await self.logger.info(f'{request.method} {request.url} {str(request.args)}') if request.method == 'GET': - # GET 用于验证回调 URL,有效期内直接返回微信给的 echostr - msg_signature = unquote(request.args.get('msg_signature', '')) - timestamp = unquote(request.args.get('timestamp', '')) - nonce = unquote(request.args.get('nonce', '')) - echostr = unquote(request.args.get('echostr', '')) + return await self._handle_get_callback() - if not all([msg_signature, timestamp, nonce, echostr]): - await self.logger.error('请求参数缺失') - return Response('缺少参数', status=400) + if request.method == 'POST': + return await self._handle_post_callback() - ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) - if ret != 0: - await self.logger.error('验证URL失败') - return Response('验证失败', status=403) - - return Response(decrypted_str, mimetype='text/plain') - - if request.method != 'POST': - return Response('', status=405) - - self.stream_sessions.cleanup() - - msg_signature = unquote(request.args.get('msg_signature', '')) - timestamp = unquote(request.args.get('timestamp', '')) - nonce = unquote(request.args.get('nonce', '')) - - encrypted_json = await request.get_json() - encrypted_msg = (encrypted_json or {}).get('encrypt', '') - if not encrypted_msg: - await self.logger.error("请求体中缺少 'encrypt' 字段") - return Response('Bad Request', status=400) - - xml_post_data = f"" - ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce) - if ret != 0: - await self.logger.error('解密失败') - return Response('解密失败', status=400) - - msg_json = json.loads(decrypted_xml) - - if msg_json.get('msgtype') == 'stream': - # 企业微信刷新请求:尝试从队列中取出增量回复 - return await self._handle_stream_refresh(msg_json, nonce) - - # 首次请求:快速返回 stream_id 并异步处理流水线 - return await self._handle_initial_message(msg_json, nonce) + return Response('', status=405) except Exception: await self.logger.error(traceback.format_exc()) return Response('Internal Server Error', status=500) + async def _handle_get_callback(self) -> tuple[Response, int] | Response: + """处理企业微信的 GET 验证请求。""" + + msg_signature = unquote(request.args.get('msg_signature', '')) + timestamp = unquote(request.args.get('timestamp', '')) + nonce = unquote(request.args.get('nonce', '')) + echostr = unquote(request.args.get('echostr', '')) + + if not all([msg_signature, timestamp, nonce, echostr]): + await self.logger.error('请求参数缺失') + return Response('缺少参数', status=400) + + ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr) + if ret != 0: + await self.logger.error('验证URL失败') + return Response('验证失败', status=403) + + return Response(decrypted_str, mimetype='text/plain') + + async def _handle_post_callback(self) -> tuple[Response, int] | Response: + """处理企业微信的 POST 回调请求。""" + + self.stream_sessions.cleanup() + + msg_signature = unquote(request.args.get('msg_signature', '')) + timestamp = unquote(request.args.get('timestamp', '')) + nonce = unquote(request.args.get('nonce', '')) + + encrypted_json = await request.get_json() + encrypted_msg = (encrypted_json or {}).get('encrypt', '') + if not encrypted_msg: + await self.logger.error("请求体中缺少 'encrypt' 字段") + return Response('Bad Request', status=400) + + xml_post_data = f"" + ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce) + if ret != 0: + await self.logger.error('解密失败') + return Response('解密失败', status=400) + + msg_json = json.loads(decrypted_xml) + + if msg_json.get('msgtype') == 'stream': + return await self._handle_post_followup_response(msg_json, nonce) + + return await self._handle_post_initial_response(msg_json, nonce) + async def get_message(self, msg_json): message_data = {}