mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
Compare commits
1 Commits
2a87419fb2
...
feature/we
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a4995c8cd9 |
@@ -23,25 +23,20 @@ xml_template = """
|
||||
|
||||
|
||||
class OAClient:
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None, unified_mode: bool = False):
|
||||
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
self.appid = AppID
|
||||
self.appsecret = Appsecret
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.access_token = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -51,39 +46,19 @@ class OAClient:
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
# 每隔100毫秒查询是否生成ai回答
|
||||
start_time = time.time()
|
||||
signature = req.args.get('signature', '')
|
||||
timestamp = req.args.get('timestamp', '')
|
||||
nonce = req.args.get('nonce', '')
|
||||
echostr = req.args.get('echostr', '')
|
||||
msg_signature = req.args.get('msg_signature', '')
|
||||
signature = request.args.get('signature', '')
|
||||
timestamp = request.args.get('timestamp', '')
|
||||
nonce = request.args.get('nonce', '')
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
if msg_signature is None:
|
||||
await self.logger.error('msg_signature不在请求体中')
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if req.method == 'GET':
|
||||
if request.method == 'GET':
|
||||
# 校验签名
|
||||
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||
@@ -93,8 +68,8 @@ class OAClient:
|
||||
else:
|
||||
await self.logger.error('拒绝请求')
|
||||
raise Exception('拒绝请求')
|
||||
elif req.method == 'POST':
|
||||
encryt_msg = await req.data
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
@@ -207,7 +182,6 @@ class OAClientForLongerResponse:
|
||||
Appsecret: str,
|
||||
LoadingMessage: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
):
|
||||
self.token = token
|
||||
self.aes = EncodingAESKey
|
||||
@@ -215,18 +189,13 @@ class OAClientForLongerResponse:
|
||||
self.appsecret = Appsecret
|
||||
self.base_url = 'https://api.weixin.qq.com'
|
||||
self.access_token = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -237,44 +206,24 @@ class OAClientForLongerResponse:
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
signature = req.args.get('signature', '')
|
||||
timestamp = req.args.get('timestamp', '')
|
||||
nonce = req.args.get('nonce', '')
|
||||
echostr = req.args.get('echostr', '')
|
||||
msg_signature = req.args.get('msg_signature', '')
|
||||
signature = request.args.get('signature', '')
|
||||
timestamp = request.args.get('timestamp', '')
|
||||
nonce = request.args.get('nonce', '')
|
||||
echostr = request.args.get('echostr', '')
|
||||
msg_signature = request.args.get('msg_signature', '')
|
||||
|
||||
if msg_signature is None:
|
||||
await self.logger.error('msg_signature不在请求体中')
|
||||
raise Exception('msg_signature不在请求体中')
|
||||
|
||||
if req.method == 'GET':
|
||||
if request.method == 'GET':
|
||||
check_str = ''.join(sorted([self.token, timestamp, nonce]))
|
||||
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
|
||||
return echostr if check_signature == signature else '拒绝请求'
|
||||
|
||||
elif req.method == 'POST':
|
||||
encryt_msg = await req.data
|
||||
elif request.method == 'POST':
|
||||
encryt_msg = await request.data
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
|
||||
xml_msg = xml_msg.decode('utf-8')
|
||||
|
||||
@@ -10,20 +10,38 @@ import traceback
|
||||
from cryptography.hazmat.primitives.asymmetric import ed25519
|
||||
|
||||
|
||||
def handle_validation(body: dict, bot_secret: str):
|
||||
# bot正确的secert是32位的,此处仅为了适配演示demo
|
||||
while len(bot_secret) < 32:
|
||||
bot_secret = bot_secret * 2
|
||||
bot_secret = bot_secret[:32]
|
||||
# 实际使用场景中以上三行内容可清除
|
||||
|
||||
seed_bytes = bot_secret.encode()
|
||||
|
||||
signing_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed_bytes)
|
||||
|
||||
msg = body['d']['event_ts'] + body['d']['plain_token']
|
||||
msg_bytes = msg.encode()
|
||||
|
||||
signature = signing_key.sign(msg_bytes)
|
||||
|
||||
signature_hex = signature.hex()
|
||||
|
||||
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
|
||||
|
||||
return response
|
||||
|
||||
|
||||
class QQOfficialClient:
|
||||
def __init__(self, secret: str, token: str, app_id: str, logger: None, unified_mode: bool = False):
|
||||
self.unified_mode = unified_mode
|
||||
def __init__(self, secret: str, token: str, app_id: str, logger: None):
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self.secret = secret
|
||||
self.token = token
|
||||
self.app_id = app_id
|
||||
@@ -64,45 +82,18 @@ class QQOfficialClient:
|
||||
raise Exception(f'获取access_token失败: {e}')
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
"""处理回调请求"""
|
||||
try:
|
||||
|
||||
body = await req.get_data()
|
||||
|
||||
print(f'[QQ Official] Received request, body length: {len(body)}')
|
||||
|
||||
if not body or len(body) == 0:
|
||||
print('[QQ Official] Received empty body, might be health check or GET request')
|
||||
return {'code': 0, 'message': 'ok'}, 200
|
||||
|
||||
# 读取请求数据
|
||||
body = await request.get_data()
|
||||
payload = json.loads(body)
|
||||
|
||||
|
||||
# 验证是否为回调验证请求
|
||||
if payload.get('op') == 13:
|
||||
validation_data = payload.get('d')
|
||||
if not validation_data:
|
||||
return {'error': "missing 'd' field"}, 400
|
||||
response = await self.verify(validation_data)
|
||||
return response, 200
|
||||
# 生成签名
|
||||
response = handle_validation(payload, self.secret)
|
||||
|
||||
return response
|
||||
|
||||
if payload.get('op') == 0:
|
||||
message_data = await self.get_message(payload)
|
||||
@@ -113,7 +104,6 @@ class QQOfficialClient:
|
||||
return {'code': 0, 'message': 'success'}
|
||||
|
||||
except Exception as e:
|
||||
print(f'[QQ Official] ERROR: {traceback.format_exc()}')
|
||||
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
|
||||
return {'error': str(e)}, 400
|
||||
|
||||
@@ -271,26 +261,3 @@ class QQOfficialClient:
|
||||
if self.access_token_expiry_time is None:
|
||||
return True
|
||||
return time.time() > self.access_token_expiry_time
|
||||
|
||||
async def repeat_seed(self, bot_secret: str, target_size: int = 32) -> bytes:
|
||||
seed = bot_secret
|
||||
while len(seed) < target_size:
|
||||
seed *= 2
|
||||
return seed[:target_size].encode("utf-8")
|
||||
|
||||
async def verify(self, validation_payload: dict):
|
||||
seed = await self.repeat_seed(self.secret)
|
||||
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
|
||||
|
||||
event_ts = validation_payload.get("event_ts", "")
|
||||
plain_token = validation_payload.get("plain_token", "")
|
||||
msg = event_ts + plain_token
|
||||
|
||||
# sign
|
||||
signature = private_key.sign(msg.encode()).hex()
|
||||
|
||||
response = {
|
||||
"plain_token": plain_token,
|
||||
"signature": signature,
|
||||
}
|
||||
return response
|
||||
|
||||
@@ -8,19 +8,14 @@ import langbot_plugin.api.entities.builtin.platform.events as platform_events
|
||||
|
||||
|
||||
class SlackClient:
|
||||
def __init__(self, bot_token: str, signing_secret: str, logger: None, unified_mode: bool = False):
|
||||
def __init__(self, bot_token: str, signing_secret: str, logger: None):
|
||||
self.bot_token = bot_token
|
||||
self.signing_secret = signing_secret
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
self.client = AsyncWebClient(self.bot_token)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -28,28 +23,8 @@ class SlackClient:
|
||||
self.logger = logger
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
body = await req.get_data()
|
||||
body = await request.get_data()
|
||||
data = json.loads(body)
|
||||
if 'type' in data:
|
||||
if data['type'] == 'url_verification':
|
||||
|
||||
@@ -1,477 +1,189 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import traceback
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Optional
|
||||
from urllib.parse import unquote
|
||||
import hashlib
|
||||
import traceback
|
||||
|
||||
import httpx
|
||||
from Crypto.Cipher import AES
|
||||
from quart import Quart, request, Response, jsonify
|
||||
|
||||
from libs.wecom_ai_bot_api import wecombotevent
|
||||
from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
|
||||
from quart import Quart, request, Response, jsonify
|
||||
import langbot_plugin.api.entities.builtin.platform.message as platform_message
|
||||
import asyncio
|
||||
from libs.wecom_ai_bot_api import wecombotevent
|
||||
from typing import Callable
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
from pkg.platform.logger import EventLogger
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamChunk:
|
||||
"""描述单次推送给企业微信的流式片段。"""
|
||||
|
||||
# 需要返回给企业微信的文本内容
|
||||
content: str
|
||||
|
||||
# 标记是否为最终片段,对应企业微信协议里的 finish 字段
|
||||
is_final: bool = False
|
||||
|
||||
# 预留额外元信息,未来支持多模态扩展时可使用
|
||||
meta: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass
|
||||
class StreamSession:
|
||||
"""维护一次企业微信流式会话的上下文。"""
|
||||
|
||||
# 企业微信要求的 stream_id,用于标识后续刷新请求
|
||||
stream_id: str
|
||||
|
||||
# 原始消息的 msgid,便于与流水线消息对应
|
||||
msg_id: str
|
||||
|
||||
# 群聊会话标识(单聊时为空)
|
||||
chat_id: Optional[str]
|
||||
|
||||
# 触发消息的发送者
|
||||
user_id: Optional[str]
|
||||
|
||||
# 会话创建时间
|
||||
created_at: float = field(default_factory=time.time)
|
||||
|
||||
# 最近一次被访问的时间,cleanup 依据该值判断过期
|
||||
last_access: float = field(default_factory=time.time)
|
||||
|
||||
# 将流水线增量结果缓存到队列,刷新请求逐条消费
|
||||
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
|
||||
|
||||
# 是否已经完成(收到最终片段)
|
||||
finished: bool = False
|
||||
|
||||
# 缓存最近一次片段,处理重试或超时兜底
|
||||
last_chunk: Optional[StreamChunk] = None
|
||||
|
||||
|
||||
class StreamSessionManager:
|
||||
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
|
||||
|
||||
def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
|
||||
self.logger = logger
|
||||
|
||||
self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup
|
||||
self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
|
||||
self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
|
||||
|
||||
def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
|
||||
if not msg_id:
|
||||
return None
|
||||
return self._msg_index.get(msg_id)
|
||||
|
||||
def get_session(self, stream_id: str) -> Optional[StreamSession]:
|
||||
return self._sessions.get(stream_id)
|
||||
|
||||
def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
|
||||
"""根据企业微信回调创建或获取会话。
|
||||
|
||||
Args:
|
||||
msg_json: 企业微信解密后的回调 JSON。
|
||||
|
||||
Returns:
|
||||
Tuple[StreamSession, bool]: `StreamSession` 为会话实例,`bool` 指示是否为新建会话。
|
||||
|
||||
Example:
|
||||
在首次回调中调用,得到 `is_new=True` 后再触发流水线。
|
||||
"""
|
||||
msg_id = msg_json.get('msgid', '')
|
||||
if msg_id and msg_id in self._msg_index:
|
||||
stream_id = self._msg_index[msg_id]
|
||||
session = self._sessions.get(stream_id)
|
||||
if session:
|
||||
session.last_access = time.time()
|
||||
return session, False
|
||||
|
||||
stream_id = str(uuid.uuid4())
|
||||
session = StreamSession(
|
||||
stream_id=stream_id,
|
||||
msg_id=msg_id,
|
||||
chat_id=msg_json.get('chatid'),
|
||||
user_id=msg_json.get('from', {}).get('userid'),
|
||||
)
|
||||
|
||||
if msg_id:
|
||||
self._msg_index[msg_id] = stream_id
|
||||
self._sessions[stream_id] = session
|
||||
return session, True
|
||||
|
||||
async def publish(self, stream_id: str, chunk: StreamChunk) -> bool:
|
||||
"""向 stream 队列写入新的增量片段。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信分配的流式会话 ID。
|
||||
chunk: 待发送的增量片段。
|
||||
|
||||
Returns:
|
||||
bool: 当流式队列存在并成功入队时返回 True。
|
||||
|
||||
Example:
|
||||
在收到模型增量后调用 `await manager.publish('sid', StreamChunk('hello'))`。
|
||||
"""
|
||||
session = self._sessions.get(stream_id)
|
||||
if not session:
|
||||
return False
|
||||
|
||||
session.last_access = time.time()
|
||||
session.last_chunk = chunk
|
||||
|
||||
try:
|
||||
session.queue.put_nowait(chunk)
|
||||
except asyncio.QueueFull:
|
||||
# 默认无界队列,此处兜底防御
|
||||
await session.queue.put(chunk)
|
||||
|
||||
if chunk.is_final:
|
||||
session.finished = True
|
||||
|
||||
return True
|
||||
|
||||
async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]:
|
||||
"""从队列中取出一个片段,若超时返回 None。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信流式会话 ID。
|
||||
timeout: 取片段的最长等待时间(秒)。
|
||||
|
||||
Returns:
|
||||
Optional[StreamChunk]: 成功时返回片段,超时或会话不存在时返回 None。
|
||||
|
||||
Example:
|
||||
企业微信刷新到达时调用,若队列有数据则立即返回 `StreamChunk`。
|
||||
"""
|
||||
session = self._sessions.get(stream_id)
|
||||
if not session:
|
||||
return None
|
||||
|
||||
session.last_access = time.time()
|
||||
|
||||
try:
|
||||
chunk = await asyncio.wait_for(session.queue.get(), timeout)
|
||||
session.last_access = time.time()
|
||||
if chunk.is_final:
|
||||
session.finished = True
|
||||
return chunk
|
||||
except asyncio.TimeoutError:
|
||||
if session.finished and session.last_chunk:
|
||||
return session.last_chunk
|
||||
return None
|
||||
|
||||
def mark_finished(self, stream_id: str) -> None:
|
||||
session = self._sessions.get(stream_id)
|
||||
if session:
|
||||
session.finished = True
|
||||
session.last_access = time.time()
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""定期清理过期会话,防止队列与映射无上限累积。"""
|
||||
now = time.time()
|
||||
expired: list[str] = []
|
||||
for stream_id, session in self._sessions.items():
|
||||
if now - session.last_access > self.ttl:
|
||||
expired.append(stream_id)
|
||||
|
||||
for stream_id in expired:
|
||||
session = self._sessions.pop(stream_id, None)
|
||||
if not session:
|
||||
continue
|
||||
msg_id = session.msg_id
|
||||
if msg_id and self._msg_index.get(msg_id) == stream_id:
|
||||
self._msg_index.pop(msg_id, None)
|
||||
|
||||
|
||||
class WecomBotClient:
|
||||
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
|
||||
"""企业微信智能机器人客户端。
|
||||
|
||||
Args:
|
||||
Token: 企业微信回调验证使用的 token。
|
||||
EnCodingAESKey: 企业微信消息加解密密钥。
|
||||
Corpid: 企业 ID。
|
||||
logger: 日志记录器。
|
||||
unified_mode: 是否使用统一 webhook 模式(默认 False)。
|
||||
|
||||
Example:
|
||||
>>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
|
||||
"""
|
||||
|
||||
self.Token = Token
|
||||
self.EnCodingAESKey = EnCodingAESKey
|
||||
self.Corpid = Corpid
|
||||
def __init__(self,Token:str,EnCodingAESKey:str,Corpid:str,logger:EventLogger):
|
||||
self.Token=Token
|
||||
self.EnCodingAESKey=EnCodingAESKey
|
||||
self.Corpid=Corpid
|
||||
self.ReceiveId = ''
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['POST', 'GET']
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['POST','GET']
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
self.user_stream_map = {}
|
||||
self.logger = logger
|
||||
self.generated_content: dict[str, str] = {}
|
||||
self.msg_id_map: dict[str, int] = {}
|
||||
self.stream_sessions = StreamSessionManager(logger=logger)
|
||||
self.stream_poll_timeout = 0.5
|
||||
|
||||
@staticmethod
|
||||
def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]:
|
||||
"""按照企业微信协议拼装返回报文。
|
||||
|
||||
Args:
|
||||
stream_id: 企业微信会话 ID。
|
||||
content: 推送的文本内容。
|
||||
finish: 是否为最终片段。
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: 可直接加密返回的 payload。
|
||||
|
||||
Example:
|
||||
组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
|
||||
"""
|
||||
return {
|
||||
'msgtype': 'stream',
|
||||
'stream': {
|
||||
'id': stream_id,
|
||||
'finish': finish,
|
||||
'content': content,
|
||||
},
|
||||
}
|
||||
|
||||
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""对响应进行加密封装并返回给企业微信。
|
||||
|
||||
Args:
|
||||
payload: 待加密的响应内容。
|
||||
nonce: 企业微信回调参数中的 nonce。
|
||||
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 对象及状态码。
|
||||
|
||||
Example:
|
||||
在首包或刷新场景中调用以生成加密响应。
|
||||
"""
|
||||
reply_plain_str = json.dumps(payload, ensure_ascii=False)
|
||||
reply_timestamp = str(int(time.time()))
|
||||
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
|
||||
if ret != 0:
|
||||
await self.logger.error(f'加密失败: {ret}')
|
||||
return jsonify({'error': 'encrypt_failed'}), 500
|
||||
|
||||
root = ET.fromstring(encrypt_text)
|
||||
encrypt = root.find('Encrypt').text
|
||||
resp = {
|
||||
'encrypt': encrypt,
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
|
||||
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None:
|
||||
"""异步触发流水线处理,避免阻塞首包响应。
|
||||
|
||||
Args:
|
||||
event: 由企业微信消息转换的内部事件对象。
|
||||
"""
|
||||
try:
|
||||
await self._handle_message(event)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
|
||||
async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""处理企业微信首次推送的消息,返回 stream_id 并开启流水线。
|
||||
|
||||
Args:
|
||||
msg_json: 解密后的企业微信消息 JSON。
|
||||
nonce: 企业微信回调参数 nonce。
|
||||
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 及状态码。
|
||||
|
||||
Example:
|
||||
首次回调时调用,立即返回带 `stream_id` 的响应。
|
||||
"""
|
||||
session, is_new = self.stream_sessions.create_or_get(msg_json)
|
||||
|
||||
message_data = await self.get_message(msg_json)
|
||||
if message_data:
|
||||
message_data['stream_id'] = session.stream_id
|
||||
try:
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
except Exception:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
else:
|
||||
if is_new:
|
||||
asyncio.create_task(self._dispatch_event(event))
|
||||
|
||||
payload = self._build_stream_payload(session.stream_id, '', False)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
|
||||
"""处理企业微信的流式刷新请求,按需返回增量片段。
|
||||
|
||||
Args:
|
||||
msg_json: 解密后的企业微信刷新请求。
|
||||
nonce: 企业微信回调参数 nonce。
|
||||
|
||||
Returns:
|
||||
Tuple[Response, int]: Quart Response 及状态码。
|
||||
|
||||
Example:
|
||||
在刷新请求中调用,按需返回增量片段。
|
||||
"""
|
||||
stream_info = msg_json.get('stream', {})
|
||||
stream_id = stream_info.get('id', '')
|
||||
if not stream_id:
|
||||
await self.logger.error('刷新请求缺少 stream.id')
|
||||
return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)
|
||||
|
||||
session = self.stream_sessions.get_session(stream_id)
|
||||
chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)
|
||||
|
||||
if not chunk:
|
||||
cached_content = None
|
||||
if session and session.msg_id:
|
||||
cached_content = self.generated_content.pop(session.msg_id, None)
|
||||
if cached_content is not None:
|
||||
chunk = StreamChunk(content=cached_content, is_final=True)
|
||||
else:
|
||||
payload = self._build_stream_payload(stream_id, '', False)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
|
||||
payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final)
|
||||
if chunk.is_final:
|
||||
self.stream_sessions.mark_finished(stream_id)
|
||||
return await self._encrypt_and_reply(payload, nonce)
|
||||
self.generated_content = {}
|
||||
self.msg_id_map = {}
|
||||
|
||||
async def sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
|
||||
raw = "".join(sorted([token, timestamp, nonce, encrypt]))
|
||||
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""企业微信回调入口(独立端口模式,使用全局 request)。
|
||||
|
||||
Returns:
|
||||
Quart Response: 根据请求类型返回验证、首包或刷新结果。
|
||||
|
||||
Example:
|
||||
作为 Quart 路由处理函数直接注册并使用。
|
||||
"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
"""
|
||||
try:
|
||||
self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')
|
||||
await self.logger.info(f'{req.method} {req.url} {str(req.args)}')
|
||||
self.wxcpt=WXBizMsgCrypt(self.Token,self.EnCodingAESKey,'')
|
||||
|
||||
if req.method == 'GET':
|
||||
return await self._handle_get_callback(req)
|
||||
if request.method == "GET":
|
||||
|
||||
if req.method == 'POST':
|
||||
return await self._handle_post_callback(req)
|
||||
msg_signature = unquote(request.args.get("msg_signature", ""))
|
||||
timestamp = unquote(request.args.get("timestamp", ""))
|
||||
nonce = unquote(request.args.get("nonce", ""))
|
||||
echostr = unquote(request.args.get("echostr", ""))
|
||||
|
||||
return Response('', status=405)
|
||||
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||
await self.logger.error("请求参数缺失")
|
||||
return Response("缺少参数", status=400)
|
||||
|
||||
except Exception:
|
||||
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
|
||||
await self.logger.error("验证URL失败")
|
||||
return Response("验证失败", status=403)
|
||||
|
||||
return Response(decrypted_str, mimetype="text/plain")
|
||||
|
||||
elif request.method == "POST":
|
||||
msg_signature = unquote(request.args.get("msg_signature", ""))
|
||||
timestamp = unquote(request.args.get("timestamp", ""))
|
||||
nonce = unquote(request.args.get("nonce", ""))
|
||||
|
||||
try:
|
||||
timeout = 3
|
||||
interval = 0.1
|
||||
start_time = time.monotonic()
|
||||
encrypted_json = await request.get_json()
|
||||
encrypted_msg = encrypted_json.get("encrypt", "")
|
||||
if not encrypted_msg:
|
||||
await self.logger.error("请求体中缺少 'encrypt' 字段")
|
||||
|
||||
xml_post_data = f"<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>"
|
||||
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
await self.logger.error("解密失败")
|
||||
|
||||
|
||||
msg_json = json.loads(decrypted_xml)
|
||||
|
||||
from_user_id = msg_json.get("from", {}).get("userid")
|
||||
chatid = msg_json.get("chatid", "")
|
||||
|
||||
message_data = await self.get_message(msg_json)
|
||||
|
||||
|
||||
|
||||
if message_data:
|
||||
try:
|
||||
event = wecombotevent.WecomBotEvent(message_data)
|
||||
if event:
|
||||
await self._handle_message(event)
|
||||
except Exception as e:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
print(traceback.format_exc())
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
if msg_json.get('chattype','') == 'single':
|
||||
if from_user_id in self.user_stream_map:
|
||||
stream_id = self.user_stream_map[from_user_id]
|
||||
else:
|
||||
stream_id =str(uuid.uuid4())
|
||||
self.user_stream_map[from_user_id] = stream_id
|
||||
|
||||
|
||||
else:
|
||||
|
||||
if chatid in self.user_stream_map:
|
||||
stream_id = self.user_stream_map[chatid]
|
||||
else:
|
||||
stream_id = str(uuid.uuid4())
|
||||
self.user_stream_map[chatid] = stream_id
|
||||
except Exception as e:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
print(traceback.format_exc())
|
||||
while True:
|
||||
content = self.generated_content.pop(msg_json['msgid'],None)
|
||||
if content:
|
||||
reply_plain = {
|
||||
"msgtype": "stream",
|
||||
"stream": {
|
||||
"id": stream_id,
|
||||
"finish": True,
|
||||
"content": content
|
||||
}
|
||||
}
|
||||
reply_plain_str = json.dumps(reply_plain, ensure_ascii=False)
|
||||
|
||||
reply_timestamp = str(int(time.time()))
|
||||
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
|
||||
if ret != 0:
|
||||
|
||||
await self.logger.error("加密失败"+str(ret))
|
||||
|
||||
|
||||
root = ET.fromstring(encrypt_text)
|
||||
encrypt = root.find("Encrypt").text
|
||||
resp = {
|
||||
"encrypt": encrypt,
|
||||
}
|
||||
return jsonify(resp), 200
|
||||
|
||||
if time.time() - start_time > timeout:
|
||||
break
|
||||
|
||||
await asyncio.sleep(interval)
|
||||
|
||||
if self.msg_id_map.get(message_data['msgid'], 1) == 3:
|
||||
await self.logger.error('请求失效:暂不支持智能机器人超过7秒的请求,如有需求,请联系 LangBot 团队。')
|
||||
return ''
|
||||
|
||||
except Exception as e:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
print(traceback.format_exc())
|
||||
|
||||
except Exception as e:
|
||||
await self.logger.error(traceback.format_exc())
|
||||
return Response('Internal Server Error', status=500)
|
||||
print(traceback.format_exc())
|
||||
|
||||
async def _handle_get_callback(self, req) -> tuple[Response, int] | Response:
|
||||
"""处理企业微信的 GET 验证请求。"""
|
||||
|
||||
msg_signature = unquote(req.args.get('msg_signature', ''))
|
||||
timestamp = unquote(req.args.get('timestamp', ''))
|
||||
nonce = unquote(req.args.get('nonce', ''))
|
||||
echostr = unquote(req.args.get('echostr', ''))
|
||||
|
||||
if not all([msg_signature, timestamp, nonce, echostr]):
|
||||
await self.logger.error('请求参数缺失')
|
||||
return Response('缺少参数', status=400)
|
||||
|
||||
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
await self.logger.error('验证URL失败')
|
||||
return Response('验证失败', status=403)
|
||||
|
||||
return Response(decrypted_str, mimetype='text/plain')
|
||||
|
||||
async def _handle_post_callback(self, req) -> tuple[Response, int] | Response:
|
||||
"""处理企业微信的 POST 回调请求。"""
|
||||
|
||||
self.stream_sessions.cleanup()
|
||||
|
||||
msg_signature = unquote(req.args.get('msg_signature', ''))
|
||||
timestamp = unquote(req.args.get('timestamp', ''))
|
||||
nonce = unquote(req.args.get('nonce', ''))
|
||||
|
||||
encrypted_json = await req.get_json()
|
||||
encrypted_msg = (encrypted_json or {}).get('encrypt', '')
|
||||
if not encrypted_msg:
|
||||
await self.logger.error("请求体中缺少 'encrypt' 字段")
|
||||
return Response('Bad Request', status=400)
|
||||
|
||||
xml_post_data = f"<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>"
|
||||
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
await self.logger.error('解密失败')
|
||||
return Response('解密失败', status=400)
|
||||
|
||||
msg_json = json.loads(decrypted_xml)
|
||||
|
||||
if msg_json.get('msgtype') == 'stream':
|
||||
return await self._handle_post_followup_response(msg_json, nonce)
|
||||
|
||||
return await self._handle_post_initial_response(msg_json, nonce)
|
||||
|
||||
async def get_message(self, msg_json):
|
||||
|
||||
async def get_message(self,msg_json):
|
||||
message_data = {}
|
||||
|
||||
if msg_json.get('chattype', '') == 'single':
|
||||
if msg_json.get('chattype','') == 'single':
|
||||
message_data['type'] = 'single'
|
||||
elif msg_json.get('chattype', '') == 'group':
|
||||
elif msg_json.get('chattype','') == 'group':
|
||||
message_data['type'] = 'group'
|
||||
|
||||
if msg_json.get('msgtype') == 'text':
|
||||
message_data['content'] = msg_json.get('text', {}).get('content')
|
||||
message_data['content'] = msg_json.get('text',{}).get('content')
|
||||
elif msg_json.get('msgtype') == 'image':
|
||||
picurl = msg_json.get('image', {}).get('url', '')
|
||||
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64
|
||||
picurl = msg_json.get('image', {}).get('url','')
|
||||
base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64
|
||||
elif msg_json.get('msgtype') == 'mixed':
|
||||
items = msg_json.get('mixed', {}).get('msg_item', [])
|
||||
texts = []
|
||||
@@ -485,27 +197,17 @@ class WecomBotClient:
|
||||
if texts:
|
||||
message_data['content'] = "".join(texts) # 拼接所有 text
|
||||
if picurl:
|
||||
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64 # 只保留第一个 image
|
||||
|
||||
# Extract user information
|
||||
from_info = msg_json.get('from', {})
|
||||
message_data['userid'] = from_info.get('userid', '')
|
||||
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
|
||||
|
||||
# Extract chat/group information
|
||||
if msg_json.get('chattype', '') == 'group':
|
||||
message_data['chatid'] = msg_json.get('chatid', '')
|
||||
# Try to get group name if available
|
||||
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
|
||||
base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
|
||||
message_data['picurl'] = base64 # 只保留第一个 image
|
||||
|
||||
message_data['userid'] = msg_json.get('from', {}).get('userid', '')
|
||||
message_data['msgid'] = msg_json.get('msgid', '')
|
||||
|
||||
if msg_json.get('aibotid'):
|
||||
message_data['aibotid'] = msg_json.get('aibotid', '')
|
||||
|
||||
return message_data
|
||||
|
||||
|
||||
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
|
||||
"""
|
||||
处理消息事件。
|
||||
@@ -521,46 +223,10 @@ class WecomBotClient:
|
||||
for handler in self._message_handlers[msg_type]:
|
||||
await handler(event)
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
|
||||
"""将流水线片段推送到 stream 会话。
|
||||
|
||||
Args:
|
||||
msg_id: 原始企业微信消息 ID。
|
||||
content: 模型产生的片段内容。
|
||||
is_final: 是否为最终片段。
|
||||
|
||||
Returns:
|
||||
bool: 当成功写入流式队列时返回 True。
|
||||
|
||||
Example:
|
||||
在流水线 `reply_message_chunk` 中调用,将增量推送至企业微信。
|
||||
"""
|
||||
# 根据 msg_id 找到对应 stream 会话,如果不存在说明当前消息非流式
|
||||
stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id)
|
||||
if not stream_id:
|
||||
return False
|
||||
|
||||
chunk = StreamChunk(content=content, is_final=is_final)
|
||||
await self.stream_sessions.publish(stream_id, chunk)
|
||||
if is_final:
|
||||
self.stream_sessions.mark_finished(stream_id)
|
||||
return True
|
||||
print(traceback.format_exc())
|
||||
|
||||
async def set_message(self, msg_id: str, content: str):
|
||||
"""兼容旧逻辑:若无法流式返回则缓存最终结果。
|
||||
|
||||
Args:
|
||||
msg_id: 企业微信消息 ID。
|
||||
content: 最终回复的文本内容。
|
||||
|
||||
Example:
|
||||
在非流式场景下缓存最终结果以备刷新时返回。
|
||||
"""
|
||||
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
|
||||
if not handled:
|
||||
self.generated_content[msg_id] = content
|
||||
self.generated_content[msg_id] = content
|
||||
|
||||
def on_message(self, msg_type: str):
|
||||
def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]):
|
||||
@@ -571,6 +237,7 @@ class WecomBotClient:
|
||||
|
||||
return decorator
|
||||
|
||||
|
||||
async def download_url_to_base64(self, download_url, encoding_aes_key):
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(download_url)
|
||||
@@ -580,22 +247,26 @@ class WecomBotClient:
|
||||
|
||||
encrypted_bytes = response.content
|
||||
|
||||
|
||||
aes_key = base64.b64decode(encoding_aes_key + "=") # base64 补齐
|
||||
iv = aes_key[:16]
|
||||
|
||||
|
||||
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
|
||||
decrypted = cipher.decrypt(encrypted_bytes)
|
||||
|
||||
|
||||
pad_len = decrypted[-1]
|
||||
decrypted = decrypted[:-pad_len]
|
||||
|
||||
if decrypted.startswith(b"\xff\xd8"): # JPEG
|
||||
|
||||
if decrypted.startswith(b"\xff\xd8"): # JPEG
|
||||
mime_type = "image/jpeg"
|
||||
elif decrypted.startswith(b"\x89PNG"): # PNG
|
||||
mime_type = "image/png"
|
||||
elif decrypted.startswith((b"GIF87a", b"GIF89a")): # GIF
|
||||
mime_type = "image/gif"
|
||||
elif decrypted.startswith(b"BM"): # BMP
|
||||
elif decrypted.startswith(b"BM"): # BMP
|
||||
mime_type = "image/bmp"
|
||||
elif decrypted.startswith(b"II*\x00") or decrypted.startswith(b"MM\x00*"): # TIFF
|
||||
mime_type = "image/tiff"
|
||||
@@ -605,9 +276,15 @@ class WecomBotClient:
|
||||
# 转 base64
|
||||
base64_str = base64.b64encode(decrypted).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{base64_str}"
|
||||
|
||||
|
||||
async def run_task(self, host: str, port: int, *args, **kwargs):
|
||||
"""
|
||||
启动 Quart 应用。
|
||||
"""
|
||||
await self.app.run_task(host=host, port=port, *args, **kwargs)
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
@@ -22,21 +22,7 @@ class WecomBotEvent(dict):
|
||||
"""
|
||||
用户id
|
||||
"""
|
||||
return self.get('from', {}).get('userid', '') or self.get('userid', '')
|
||||
|
||||
@property
|
||||
def username(self) -> str:
|
||||
"""
|
||||
用户名称
|
||||
"""
|
||||
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
|
||||
|
||||
@property
|
||||
def chatname(self) -> str:
|
||||
"""
|
||||
群组名称
|
||||
"""
|
||||
return self.get('chatname', '') or str(self.chatid)
|
||||
return self.get('from', {}).get('userid', '')
|
||||
|
||||
@property
|
||||
def content(self) -> str:
|
||||
|
||||
@@ -21,7 +21,6 @@ class WecomClient:
|
||||
EncodingAESKey: str,
|
||||
contacts_secret: str,
|
||||
logger: None,
|
||||
unified_mode: bool = False,
|
||||
):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
@@ -32,18 +31,13 @@ class WecomClient:
|
||||
self.access_token = ''
|
||||
self.secret_for_contacts = contacts_secret
|
||||
self.logger = logger
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command',
|
||||
'handle_callback',
|
||||
self.handle_callback_request,
|
||||
methods=['GET', 'POST'],
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -174,43 +168,25 @@ class WecomClient:
|
||||
raise Exception('Failed to send message: ' + str(data))
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""
|
||||
处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||
"""
|
||||
try:
|
||||
msg_signature = req.args.get('msg_signature')
|
||||
timestamp = req.args.get('timestamp')
|
||||
nonce = req.args.get('nonce')
|
||||
msg_signature = request.args.get('msg_signature')
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
if req.method == 'GET':
|
||||
echostr = req.args.get('echostr')
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
await self.logger.error('验证失败')
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif req.method == 'POST':
|
||||
encrypt_msg = await req.data
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
await self.logger.error('消息解密失败')
|
||||
|
||||
@@ -13,7 +13,7 @@ import aiofiles
|
||||
|
||||
|
||||
class WecomCSClient:
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None, unified_mode: bool = False):
|
||||
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None):
|
||||
self.corpid = corpid
|
||||
self.secret = secret
|
||||
self.access_token_for_contacts = ''
|
||||
@@ -22,15 +22,10 @@ class WecomCSClient:
|
||||
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
|
||||
self.access_token = ''
|
||||
self.logger = logger
|
||||
self.unified_mode = unified_mode
|
||||
self.app = Quart(__name__)
|
||||
|
||||
# 只有在非统一模式下才注册独立路由
|
||||
if not self.unified_mode:
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
|
||||
self.app.add_url_rule(
|
||||
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
|
||||
)
|
||||
self._message_handlers = {
|
||||
'example': [],
|
||||
}
|
||||
@@ -197,45 +192,27 @@ class WecomCSClient:
|
||||
return data
|
||||
|
||||
async def handle_callback_request(self):
|
||||
"""处理回调请求(独立端口模式,使用全局 request)。"""
|
||||
return await self._handle_callback_internal(request)
|
||||
|
||||
async def handle_unified_webhook(self, req):
|
||||
"""处理回调请求(统一 webhook 模式,显式传递 request)。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self._handle_callback_internal(req)
|
||||
|
||||
async def _handle_callback_internal(self, req):
|
||||
"""
|
||||
处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
|
||||
|
||||
Args:
|
||||
req: Quart Request 对象
|
||||
处理回调请求,包括 GET 验证和 POST 消息接收。
|
||||
"""
|
||||
try:
|
||||
msg_signature = req.args.get('msg_signature')
|
||||
timestamp = req.args.get('timestamp')
|
||||
nonce = req.args.get('nonce')
|
||||
msg_signature = request.args.get('msg_signature')
|
||||
timestamp = request.args.get('timestamp')
|
||||
nonce = request.args.get('nonce')
|
||||
try:
|
||||
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
|
||||
except Exception as e:
|
||||
raise Exception(f'初始化失败,错误码: {e}')
|
||||
|
||||
if req.method == 'GET':
|
||||
echostr = req.args.get('echostr')
|
||||
if request.method == 'GET':
|
||||
echostr = request.args.get('echostr')
|
||||
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
|
||||
if ret != 0:
|
||||
raise Exception(f'验证失败,错误码: {ret}')
|
||||
return reply_echo_str
|
||||
|
||||
elif req.method == 'POST':
|
||||
encrypt_msg = await req.data
|
||||
elif request.method == 'POST':
|
||||
encrypt_msg = await request.data
|
||||
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
|
||||
if ret != 0:
|
||||
raise Exception(f'消息解密失败,错误码: {ret}')
|
||||
|
||||
229
pkg/api/http/controller/groups/pipelines/websocket.py
Normal file
229
pkg/api/http/controller/groups/pipelines/websocket.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""流水线调试 WebSocket 路由
|
||||
|
||||
提供基于 WebSocket 的实时双向通信,用于流水线调试。
|
||||
支持 person 和 group 两种会话类型的隔离。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from ....service.websocket_pool import WebSocketConnection
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def handle_client_event(connection: WebSocketConnection, message: dict, ap):
|
||||
"""处理客户端发送的事件
|
||||
|
||||
Args:
|
||||
connection: WebSocket 连接对象
|
||||
message: 客户端消息 {'type': 'xxx', 'data': {...}}
|
||||
ap: Application 实例
|
||||
"""
|
||||
event_type = message.get('type')
|
||||
data = message.get('data', {})
|
||||
|
||||
pipeline_uuid = connection.pipeline_uuid
|
||||
session_type = connection.session_type
|
||||
|
||||
try:
|
||||
webchat_adapter = ap.platform_mgr.webchat_proxy_bot.adapter
|
||||
|
||||
if event_type == 'send_message':
|
||||
# 发送消息到指定会话
|
||||
message_chain_obj = data.get('message_chain', [])
|
||||
client_message_id = data.get('client_message_id')
|
||||
|
||||
if not message_chain_obj:
|
||||
await connection.send('error', {'error': 'message_chain is required', 'error_code': 'INVALID_REQUEST'})
|
||||
return
|
||||
|
||||
logger.info(
|
||||
f"Received send_message: pipeline={pipeline_uuid}, "
|
||||
f"session={session_type}, "
|
||||
f"client_msg_id={client_message_id}"
|
||||
)
|
||||
|
||||
# 调用 webchat_adapter.send_webchat_message
|
||||
# 消息将通过 reply_message_chunk 自动推送到 WebSocket
|
||||
result = None
|
||||
async for msg in webchat_adapter.send_webchat_message(
|
||||
pipeline_uuid=pipeline_uuid, session_type=session_type, message_chain_obj=message_chain_obj, is_stream=True
|
||||
):
|
||||
result = msg
|
||||
|
||||
# 发送确认
|
||||
if result:
|
||||
await connection.send(
|
||||
'message_sent',
|
||||
{
|
||||
'client_message_id': client_message_id,
|
||||
'server_message_id': result.get('id'),
|
||||
'timestamp': result.get('timestamp'),
|
||||
},
|
||||
)
|
||||
|
||||
elif event_type == 'load_history':
|
||||
# 加载指定会话的历史消息
|
||||
before_message_id = data.get('before_message_id')
|
||||
limit = data.get('limit', 50)
|
||||
|
||||
logger.info(f"Loading history: pipeline={pipeline_uuid}, session={session_type}, limit={limit}")
|
||||
|
||||
# 从对应会话获取历史消息
|
||||
messages = webchat_adapter.get_webchat_messages(pipeline_uuid, session_type)
|
||||
|
||||
# 简单分页:返回最后 limit 条
|
||||
if before_message_id:
|
||||
# TODO: 实现基于 message_id 的分页
|
||||
history_messages = messages[-limit:]
|
||||
else:
|
||||
history_messages = messages[-limit:] if len(messages) > limit else messages
|
||||
|
||||
await connection.send(
|
||||
'history', {'messages': history_messages, 'has_more': len(messages) > len(history_messages)}
|
||||
)
|
||||
|
||||
elif event_type == 'interrupt':
|
||||
# 中断消息
|
||||
message_id = data.get('message_id')
|
||||
logger.info(f"Interrupt requested: message_id={message_id}")
|
||||
|
||||
# TODO: 实现中断逻辑
|
||||
await connection.send('interrupted', {'message_id': message_id, 'partial_content': ''})
|
||||
|
||||
elif event_type == 'ping':
|
||||
# 心跳
|
||||
connection.last_ping = datetime.now()
|
||||
await connection.send('pong', {'timestamp': data.get('timestamp')})
|
||||
|
||||
else:
|
||||
logger.warning(f"Unknown event type: {event_type}")
|
||||
await connection.send('error', {'error': f'Unknown event type: {event_type}', 'error_code': 'UNKNOWN_EVENT'})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error handling event {event_type}: {e}", exc_info=True)
|
||||
await connection.send(
|
||||
'error',
|
||||
{'error': f'Internal server error: {str(e)}', 'error_code': 'INTERNAL_ERROR', 'details': {'event_type': event_type}},
|
||||
)
|
||||
|
||||
|
||||
@group.group_class('pipeline-websocket', '/api/v1/pipelines/<pipeline_uuid>/chat')
|
||||
class PipelineWebSocketRouterGroup(group.RouterGroup):
|
||||
"""流水线调试 WebSocket 路由组"""
|
||||
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/ws')
|
||||
async def websocket_handler(pipeline_uuid: str):
|
||||
"""WebSocket 连接处理 - 会话隔离
|
||||
|
||||
连接流程:
|
||||
1. 客户端建立 WebSocket 连接
|
||||
2. 客户端发送 connect 事件(携带 session_type 和 token)
|
||||
3. 服务端验证并创建连接对象
|
||||
4. 进入消息循环,处理客户端事件
|
||||
5. 断开时清理连接
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
"""
|
||||
websocket = quart.websocket._get_current_object()
|
||||
connection_id = str(uuid.uuid4())
|
||||
session_key = None
|
||||
connection = None
|
||||
|
||||
try:
|
||||
# 1. 等待客户端发送 connect 事件
|
||||
first_message = await websocket.receive_json()
|
||||
|
||||
if first_message.get('type') != 'connect':
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'First message must be connect event', 'error_code': 'INVALID_HANDSHAKE'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
connect_data = first_message.get('data', {})
|
||||
session_type = connect_data.get('session_type')
|
||||
token = connect_data.get('token')
|
||||
|
||||
# 验证参数
|
||||
if session_type not in ['person', 'group']:
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'session_type must be person or group', 'error_code': 'INVALID_SESSION_TYPE'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
# 验证 token
|
||||
if not token:
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'token is required', 'error_code': 'MISSING_TOKEN'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
# 验证用户身份
|
||||
try:
|
||||
user = await self.ap.user_service.verify_token(token)
|
||||
if not user:
|
||||
await websocket.send_json({'type': 'error', 'data': {'error': 'Unauthorized', 'error_code': 'UNAUTHORIZED'}})
|
||||
await websocket.close(1008)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.error(f"Token verification failed: {e}")
|
||||
await websocket.send_json(
|
||||
{'type': 'error', 'data': {'error': 'Token verification failed', 'error_code': 'AUTH_ERROR'}}
|
||||
)
|
||||
await websocket.close(1008)
|
||||
return
|
||||
|
||||
# 2. 创建连接对象并加入连接池
|
||||
connection = WebSocketConnection(
|
||||
connection_id=connection_id,
|
||||
websocket=websocket,
|
||||
pipeline_uuid=pipeline_uuid,
|
||||
session_type=session_type,
|
||||
created_at=datetime.now(),
|
||||
last_ping=datetime.now(),
|
||||
)
|
||||
|
||||
session_key = connection.session_key
|
||||
ws_pool = self.ap.ws_pool
|
||||
ws_pool.add_connection(connection)
|
||||
|
||||
# 3. 发送连接成功事件
|
||||
await connection.send(
|
||||
'connected', {'connection_id': connection_id, 'session_type': session_type, 'pipeline_uuid': pipeline_uuid}
|
||||
)
|
||||
|
||||
logger.info(f"WebSocket connected: {connection_id} [pipeline={pipeline_uuid}, session={session_type}]")
|
||||
|
||||
# 4. 进入消息处理循环
|
||||
while True:
|
||||
try:
|
||||
message = await websocket.receive_json()
|
||||
await handle_client_event(connection, message, self.ap)
|
||||
except asyncio.CancelledError:
|
||||
logger.info(f"WebSocket connection cancelled: {connection_id}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error receiving message from {connection_id}: {e}")
|
||||
break
|
||||
|
||||
except quart.exceptions.WebsocketDisconnected:
|
||||
logger.info(f"WebSocket disconnected: {connection_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"WebSocket error for {connection_id}: {e}", exc_info=True)
|
||||
finally:
|
||||
# 清理连接
|
||||
if connection and session_key:
|
||||
ws_pool = self.ap.ws_pool
|
||||
await ws_pool.remove_connection(connection_id, session_key)
|
||||
logger.info(f"WebSocket connection cleaned up: {connection_id}")
|
||||
@@ -18,8 +18,7 @@ class BotsRouterGroup(group.RouterGroup):
|
||||
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
|
||||
async def _(bot_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
# 返回运行时信息,包括webhook地址等
|
||||
bot = await self.ap.bot_service.get_runtime_bot_info(bot_uuid)
|
||||
bot = await self.ap.bot_service.get_bot(bot_uuid)
|
||||
if bot is None:
|
||||
return self.http_status(404, -1, 'bot not found')
|
||||
return self.success(data={'bot': bot})
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import quart
|
||||
import traceback
|
||||
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('webhooks', '/bots')
|
||||
class WebhookRouterGroup(group.RouterGroup):
|
||||
async def initialize(self) -> None:
|
||||
@self.route('/<bot_uuid>', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||
async def handle_webhook(bot_uuid: str):
|
||||
"""处理 bot webhook 回调(无子路径)"""
|
||||
return await self._dispatch_webhook(bot_uuid, '')
|
||||
|
||||
@self.route('/<bot_uuid>/<path:path>', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
|
||||
async def handle_webhook_with_path(bot_uuid: str, path: str):
|
||||
"""处理 bot webhook 回调(带子路径)"""
|
||||
return await self._dispatch_webhook(bot_uuid, path)
|
||||
|
||||
async def _dispatch_webhook(self, bot_uuid: str, path: str):
|
||||
"""分发 webhook 请求到对应的 bot adapter
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
|
||||
Returns:
|
||||
适配器返回的响应
|
||||
"""
|
||||
try:
|
||||
|
||||
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
|
||||
|
||||
if not runtime_bot:
|
||||
return quart.jsonify({'error': 'Bot not found'}), 404
|
||||
|
||||
if not runtime_bot.enable:
|
||||
return quart.jsonify({'error': 'Bot is disabled'}), 403
|
||||
|
||||
|
||||
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
|
||||
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
|
||||
|
||||
|
||||
response = await runtime_bot.adapter.handle_unified_webhook(
|
||||
bot_uuid=bot_uuid,
|
||||
path=path,
|
||||
request=quart.request,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Webhook dispatch error for bot {bot_uuid}: {traceback.format_exc()}')
|
||||
return quart.jsonify({'error': str(e)}), 500
|
||||
@@ -58,15 +58,6 @@ class BotService:
|
||||
if runtime_bot is not None:
|
||||
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
|
||||
|
||||
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack','wecomcs']:
|
||||
api_port = self.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"/bots/{bot_uuid}"
|
||||
adapter_runtime_values['webhook_url'] = webhook_url
|
||||
adapter_runtime_values['webhook_full_url'] = f"http://<Your-Server-IP>:{api_port}{webhook_url}"
|
||||
else:
|
||||
adapter_runtime_values['webhook_url'] = None
|
||||
adapter_runtime_values['webhook_full_url'] = None
|
||||
|
||||
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
|
||||
|
||||
return persistence_bot
|
||||
|
||||
211
pkg/api/http/service/websocket_pool.py
Normal file
211
pkg/api/http/service/websocket_pool.py
Normal file
@@ -0,0 +1,211 @@
|
||||
"""WebSocket 连接池管理
|
||||
|
||||
用于管理流水线调试的 WebSocket 连接,支持会话隔离和消息广播。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import quart
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class WebSocketConnection:
|
||||
"""单个 WebSocket 连接"""
|
||||
|
||||
connection_id: str
|
||||
websocket: quart.websocket.WebSocket
|
||||
pipeline_uuid: str
|
||||
session_type: str # 'person' 或 'group'
|
||||
created_at: datetime
|
||||
last_ping: datetime
|
||||
|
||||
@property
|
||||
def session_key(self) -> str:
|
||||
"""会话唯一标识: pipeline_uuid:session_type"""
|
||||
return f"{self.pipeline_uuid}:{self.session_type}"
|
||||
|
||||
async def send(self, event_type: str, data: dict):
|
||||
"""发送事件到客户端
|
||||
|
||||
Args:
|
||||
event_type: 事件类型
|
||||
data: 事件数据
|
||||
"""
|
||||
try:
|
||||
await self.websocket.send_json({"type": event_type, "data": data})
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to send message to {self.connection_id}: {e}")
|
||||
raise
|
||||
|
||||
|
||||
class WebSocketConnectionPool:
|
||||
"""WebSocket 连接池 - 按会话隔离
|
||||
|
||||
连接池结构:
|
||||
connections[session_key][connection_id] = WebSocketConnection
|
||||
其中 session_key = f"{pipeline_uuid}:{session_type}"
|
||||
|
||||
这样可以确保:
|
||||
- person 和 group 会话完全隔离
|
||||
- 不同 pipeline 的会话隔离
|
||||
- 同一会话的多个连接可以同步接收消息(多标签页)
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.connections: dict[str, dict[str, WebSocketConnection]] = {}
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
def add_connection(self, conn: WebSocketConnection):
|
||||
"""添加连接到指定会话
|
||||
|
||||
Args:
|
||||
conn: WebSocket 连接对象
|
||||
"""
|
||||
session_key = conn.session_key
|
||||
|
||||
if session_key not in self.connections:
|
||||
self.connections[session_key] = {}
|
||||
|
||||
self.connections[session_key][conn.connection_id] = conn
|
||||
|
||||
logger.info(
|
||||
f"WebSocket connection added: {conn.connection_id} "
|
||||
f"to session {session_key} "
|
||||
f"(total: {len(self.connections[session_key])} connections)"
|
||||
)
|
||||
|
||||
async def remove_connection(self, connection_id: str, session_key: str):
|
||||
"""从指定会话移除连接
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID
|
||||
session_key: 会话标识
|
||||
"""
|
||||
async with self._lock:
|
||||
if session_key in self.connections:
|
||||
conn = self.connections[session_key].pop(connection_id, None)
|
||||
|
||||
# 如果该会话没有连接了,清理会话
|
||||
if not self.connections[session_key]:
|
||||
del self.connections[session_key]
|
||||
|
||||
if conn:
|
||||
logger.info(
|
||||
f"WebSocket connection removed: {connection_id} "
|
||||
f"from session {session_key} "
|
||||
f"(remaining: {len(self.connections.get(session_key, {}))} connections)"
|
||||
)
|
||||
|
||||
def get_connection(self, connection_id: str, session_key: str) -> Optional[WebSocketConnection]:
|
||||
"""获取指定连接
|
||||
|
||||
Args:
|
||||
connection_id: 连接 ID
|
||||
session_key: 会话标识
|
||||
|
||||
Returns:
|
||||
WebSocketConnection 或 None
|
||||
"""
|
||||
return self.connections.get(session_key, {}).get(connection_id)
|
||||
|
||||
def get_connections_by_session(self, pipeline_uuid: str, session_type: str) -> list[WebSocketConnection]:
|
||||
"""获取指定会话的所有连接
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
session_type: 会话类型 ('person' 或 'group')
|
||||
|
||||
Returns:
|
||||
连接列表
|
||||
"""
|
||||
session_key = f"{pipeline_uuid}:{session_type}"
|
||||
return list(self.connections.get(session_key, {}).values())
|
||||
|
||||
async def broadcast_to_session(self, pipeline_uuid: str, session_type: str, event_type: str, data: dict):
|
||||
"""广播消息到指定会话的所有连接
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
session_type: 会话类型 ('person' 或 'group')
|
||||
event_type: 事件类型
|
||||
data: 事件数据
|
||||
"""
|
||||
connections = self.get_connections_by_session(pipeline_uuid, session_type)
|
||||
|
||||
if not connections:
|
||||
logger.debug(f"No connections for session {pipeline_uuid}:{session_type}, skipping broadcast")
|
||||
return
|
||||
|
||||
logger.debug(
|
||||
f"Broadcasting {event_type} to session {pipeline_uuid}:{session_type}, " f"{len(connections)} connections"
|
||||
)
|
||||
|
||||
# 并发发送到所有连接,忽略失败的连接
|
||||
results = await asyncio.gather(*[conn.send(event_type, data) for conn in connections], return_exceptions=True)
|
||||
|
||||
# 统计失败的连接
|
||||
failed_count = sum(1 for result in results if isinstance(result, Exception))
|
||||
if failed_count > 0:
|
||||
logger.warning(f"Failed to send to {failed_count}/{len(connections)} connections")
|
||||
|
||||
def get_all_sessions(self) -> list[str]:
|
||||
"""获取所有活跃会话的 session_key 列表
|
||||
|
||||
Returns:
|
||||
会话标识列表
|
||||
"""
|
||||
return list(self.connections.keys())
|
||||
|
||||
def get_connection_count(self, pipeline_uuid: str, session_type: str) -> int:
|
||||
"""获取指定会话的连接数量
|
||||
|
||||
Args:
|
||||
pipeline_uuid: 流水线 UUID
|
||||
session_type: 会话类型
|
||||
|
||||
Returns:
|
||||
连接数量
|
||||
"""
|
||||
session_key = f"{pipeline_uuid}:{session_type}"
|
||||
return len(self.connections.get(session_key, {}))
|
||||
|
||||
async def cleanup_stale_connections(self, timeout_seconds: int = 120):
|
||||
"""清理超时的连接
|
||||
|
||||
Args:
|
||||
timeout_seconds: 超时时间(秒)
|
||||
"""
|
||||
now = datetime.now()
|
||||
stale_connections = []
|
||||
|
||||
# 查找超时连接
|
||||
for session_key, session_conns in self.connections.items():
|
||||
for conn_id, conn in session_conns.items():
|
||||
elapsed = (now - conn.last_ping).total_seconds()
|
||||
if elapsed > timeout_seconds:
|
||||
stale_connections.append((conn_id, session_key))
|
||||
|
||||
# 移除超时连接
|
||||
for conn_id, session_key in stale_connections:
|
||||
logger.warning(f"Removing stale connection: {conn_id} from {session_key}")
|
||||
await self.remove_connection(conn_id, session_key)
|
||||
|
||||
# 尝试关闭 WebSocket
|
||||
try:
|
||||
conn = self.get_connection(conn_id, session_key)
|
||||
if conn:
|
||||
await conn.websocket.close(1000, "Connection timeout")
|
||||
except Exception as e:
|
||||
logger.error(f"Error closing stale connection {conn_id}: {e}")
|
||||
|
||||
if stale_connections:
|
||||
logger.info(f"Cleaned up {len(stale_connections)} stale connections")
|
||||
@@ -105,6 +105,10 @@ class Application:
|
||||
|
||||
storage_mgr: storagemgr.StorageMgr = None
|
||||
|
||||
# ========= WebSocket =========
|
||||
|
||||
ws_pool = None # WebSocketConnectionPool
|
||||
|
||||
# ========= HTTP Services =========
|
||||
|
||||
user_service: user_service.UserService = None
|
||||
|
||||
@@ -19,6 +19,7 @@ from ...api.http.service import model as model_service
|
||||
from ...api.http.service import pipeline as pipeline_service
|
||||
from ...api.http.service import bot as bot_service
|
||||
from ...api.http.service import knowledge as knowledge_service
|
||||
from ...api.http.service import websocket_pool
|
||||
from ...discover import engine as discover_engine
|
||||
from ...storage import mgr as storagemgr
|
||||
from ...utils import logcache
|
||||
@@ -87,10 +88,18 @@ class BuildAppStage(stage.BootingStage):
|
||||
await llm_tool_mgr_inst.initialize()
|
||||
ap.tool_mgr = llm_tool_mgr_inst
|
||||
|
||||
# Initialize WebSocket connection pool
|
||||
ws_pool_inst = websocket_pool.WebSocketConnectionPool()
|
||||
ap.ws_pool = ws_pool_inst
|
||||
|
||||
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
|
||||
await im_mgr_inst.initialize()
|
||||
ap.platform_mgr = im_mgr_inst
|
||||
|
||||
# Inject WebSocket pool into WebChatAdapter
|
||||
if hasattr(ap.platform_mgr, 'webchat_proxy_bot') and ap.platform_mgr.webchat_proxy_bot:
|
||||
ap.platform_mgr.webchat_proxy_bot.adapter.set_ws_pool(ws_pool_inst)
|
||||
|
||||
pipeline_mgr = pipelinemgr.PipelineManager(ap)
|
||||
await pipeline_mgr.initialize()
|
||||
ap.pipeline_mgr = pipeline_mgr
|
||||
|
||||
@@ -232,10 +232,6 @@ class PlatformManager:
|
||||
logger,
|
||||
)
|
||||
|
||||
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid(用于统一 webhook)
|
||||
if hasattr(adapter_inst, 'set_bot_uuid'):
|
||||
adapter_inst.set_bot_uuid(bot_entity.uuid)
|
||||
|
||||
runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst, logger=logger)
|
||||
|
||||
await runtime_bot.initialize()
|
||||
|
||||
@@ -59,16 +59,14 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
message_converter: OAMessageConverter = OAMessageConverter()
|
||||
event_converter: OAEventConverter = OAEventConverter()
|
||||
bot: typing.Union[OAClient, OAClientForLongerResponse] = pydantic.Field(exclude=True)
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
# 校验必填项
|
||||
required_keys = ['token', 'EncodingAESKey', 'AppSecret', 'AppID', 'Mode']
|
||||
missing_keys = [k for k in required_keys if k not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'OfficialAccount 缺少配置项: {missing_keys}')
|
||||
|
||||
# 创建运行时 bot 对象,始终使用统一 webhook 模式
|
||||
|
||||
if config['Mode'] == 'drop':
|
||||
bot = OAClient(
|
||||
token=config['token'],
|
||||
@@ -76,7 +74,6 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
Appsecret=config['AppSecret'],
|
||||
AppID=config['AppID'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
elif config['Mode'] == 'passive':
|
||||
bot = OAClientForLongerResponse(
|
||||
@@ -86,14 +83,13 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
AppID=config['AppID'],
|
||||
LoadingMessage=config.get('LoadingMessage', ''),
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
else:
|
||||
raise KeyError('请设置微信公众号通信模式')
|
||||
|
||||
bot_account_id = config.get('AppID', '')
|
||||
|
||||
|
||||
|
||||
super().__init__(
|
||||
bot=bot,
|
||||
bot_account_id=bot_account_id,
|
||||
@@ -140,45 +136,16 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
# 打印 webhook 回调地址
|
||||
if self.bot_uuid and hasattr(self.logger, 'ap'):
|
||||
try:
|
||||
api_port = self.logger.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
|
||||
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
|
||||
|
||||
await self.logger.info(f"微信公众号 Webhook 回调地址:")
|
||||
await self.logger.info(f" 本地地址: {webhook_url}")
|
||||
await self.logger.info(f" 公网地址: {webhook_url_public}")
|
||||
await self.logger.info(f"请在微信公众号后台配置此回调地址")
|
||||
except Exception as e:
|
||||
await self.logger.warning(f"无法生成 webhook URL: {e}")
|
||||
|
||||
async def keep_alive():
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
await keep_alive()
|
||||
|
||||
await self.bot.run_task(
|
||||
host=self.config['host'],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -53,6 +53,23 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: "AI正在思考中,请发送任意内容获取回复。"
|
||||
- name: host
|
||||
label:
|
||||
en_US: Host
|
||||
zh_Hans: 监听主机
|
||||
description:
|
||||
en_US: The host that Official Account listens on for Webhook connections.
|
||||
zh_Hans: 微信公众号监听的主机,除非你知道自己在做什么,否则请写 0.0.0.0
|
||||
type: string
|
||||
required: true
|
||||
default: 0.0.0.0
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2287
|
||||
execution:
|
||||
python:
|
||||
path: ./officialaccount.py
|
||||
|
||||
@@ -135,17 +135,12 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
bot: QQOfficialClient
|
||||
config: dict
|
||||
bot_account_id: str
|
||||
bot_uuid: str = None
|
||||
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
|
||||
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
bot = QQOfficialClient(
|
||||
app_id=config['appid'],
|
||||
secret=config['secret'],
|
||||
token=config['token'],
|
||||
logger=logger,
|
||||
unified_mode=True
|
||||
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
@@ -231,45 +226,16 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
|
||||
self.bot.on_message('GROUP_AT_MESSAGE_CREATE')(on_message)
|
||||
self.bot.on_message('AT_MESSAGE_CREATE')(on_message)
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
# 打印 webhook 回调地址
|
||||
if self.bot_uuid and hasattr(self.logger, 'ap'):
|
||||
try:
|
||||
api_port = self.logger.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
|
||||
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
|
||||
|
||||
await self.logger.info(f"QQ 官方机器人 Webhook 回调地址:")
|
||||
await self.logger.info(f" 本地地址: {webhook_url}")
|
||||
await self.logger.info(f" 公网地址: {webhook_url_public}")
|
||||
await self.logger.info(f"请在 QQ 官方机器人后台配置此回调地址")
|
||||
except Exception as e:
|
||||
await self.logger.warning(f"无法生成 webhook URL: {e}")
|
||||
|
||||
async def keep_alive():
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
await keep_alive()
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -25,6 +25,13 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2284
|
||||
- name: token
|
||||
label:
|
||||
en_US: Token
|
||||
|
||||
@@ -86,12 +86,13 @@ class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: SlackClient
|
||||
bot_account_id: str
|
||||
bot_uuid: str = None
|
||||
message_converter: SlackMessageConverter = SlackMessageConverter()
|
||||
event_converter: SlackEventConverter = SlackEventConverter()
|
||||
config: dict
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
self.config = config
|
||||
self.logger = logger
|
||||
required_keys = [
|
||||
'bot_token',
|
||||
'signing_secret',
|
||||
@@ -100,18 +101,8 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
if missing_keys:
|
||||
raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项,请查看文档或联系管理员')
|
||||
|
||||
bot = SlackClient(
|
||||
bot_token=config['bot_token'],
|
||||
signing_secret=config['signing_secret'],
|
||||
logger=logger,
|
||||
unified_mode=True
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
bot=bot,
|
||||
bot_account_id=config['bot_token'],
|
||||
self.bot = SlackClient(
|
||||
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
|
||||
)
|
||||
|
||||
async def reply_message(
|
||||
@@ -157,45 +148,16 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
self.bot.on_message('channel')(on_message)
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
# 打印 webhook 回调地址
|
||||
if self.bot_uuid and hasattr(self.logger, 'ap'):
|
||||
try:
|
||||
api_port = self.logger.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
|
||||
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
|
||||
|
||||
await self.logger.info(f"Slack 机器人 Webhook 回调地址:")
|
||||
await self.logger.info(f" 本地地址: {webhook_url}")
|
||||
await self.logger.info(f" 公网地址: {webhook_url_public}")
|
||||
await self.logger.info(f"请在 Slack 后台配置此回调地址")
|
||||
except Exception as e:
|
||||
await self.logger.warning(f"无法生成 webhook URL: {e}")
|
||||
|
||||
async def keep_alive():
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
await keep_alive()
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -25,6 +25,13 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: int
|
||||
required: true
|
||||
default: 2288
|
||||
execution:
|
||||
python:
|
||||
path: ./slack.py
|
||||
|
||||
@@ -58,6 +58,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
|
||||
|
||||
ap: app.Application = pydantic.Field(exclude=True)
|
||||
ws_pool: typing.Any = pydantic.Field(exclude=True, default=None) # WebSocketConnectionPool
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
|
||||
super().__init__(
|
||||
@@ -72,6 +73,15 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
self.bot_account_id = 'webchatbot'
|
||||
|
||||
self.debug_messages = {}
|
||||
self.ws_pool = None
|
||||
|
||||
def set_ws_pool(self, ws_pool):
|
||||
"""设置 WebSocket 连接池
|
||||
|
||||
Args:
|
||||
ws_pool: WebSocketConnectionPool 实例
|
||||
"""
|
||||
self.ws_pool = ws_pool
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
@@ -130,7 +140,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息"""
|
||||
"""Reply message chunk - supports both SSE (legacy) and WebSocket"""
|
||||
message_data = WebChatMessage(
|
||||
id=-1,
|
||||
role='assistant',
|
||||
@@ -139,24 +149,32 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
# notify waiter
|
||||
session = (
|
||||
self.webchat_group_session
|
||||
if isinstance(message_source, platform_events.GroupMessage)
|
||||
else self.webchat_person_session
|
||||
)
|
||||
if message_source.message_chain.message_id not in session.resp_waiters:
|
||||
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
|
||||
queue = session.resp_queues[message_source.message_chain.message_id]
|
||||
# Determine session type
|
||||
if isinstance(message_source, platform_events.GroupMessage):
|
||||
session_type = 'group'
|
||||
session = self.webchat_group_session
|
||||
else: # FriendMessage
|
||||
session_type = 'person'
|
||||
session = self.webchat_person_session
|
||||
|
||||
# if isinstance(message_source, platform_events.FriendMessage):
|
||||
# queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id]
|
||||
# elif isinstance(message_source, platform_events.GroupMessage):
|
||||
# queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id]
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_data.is_final = True
|
||||
# print(message_data)
|
||||
await queue.put(message_data)
|
||||
# Legacy SSE support: put message into queue
|
||||
if message_source.message_chain.message_id in session.resp_queues:
|
||||
queue = session.resp_queues[message_source.message_chain.message_id]
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_data.is_final = True
|
||||
await queue.put(message_data)
|
||||
|
||||
# WebSocket support: broadcast to all connections
|
||||
if self.ws_pool:
|
||||
pipeline_uuid = self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid
|
||||
|
||||
# Determine event type
|
||||
event_type = 'message_complete' if (is_final and bot_message.tool_calls is None) else 'message_chunk'
|
||||
|
||||
# Broadcast to specified session only
|
||||
await self.ws_pool.broadcast_to_session(
|
||||
pipeline_uuid=pipeline_uuid, session_type=session_type, event_type=event_type, data=message_data.model_dump()
|
||||
)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
|
||||
@@ -132,7 +132,6 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||
event_converter: WecomEventConverter = WecomEventConverter()
|
||||
config: dict
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
# 校验必填项
|
||||
@@ -143,12 +142,11 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
'EncodingAESKey',
|
||||
'contacts_secret',
|
||||
]
|
||||
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'Wecom 缺少配置项: {missing_keys}')
|
||||
|
||||
# 创建运行时 bot 对象,始终使用统一 webhook 模式
|
||||
# 创建运行时 bot 对象
|
||||
bot = WecomClient(
|
||||
corpid=config['corpid'],
|
||||
secret=config['secret'],
|
||||
@@ -156,10 +154,9 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
contacts_secret=config['contacts_secret'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
|
||||
|
||||
|
||||
super().__init__(
|
||||
config=config,
|
||||
logger=logger,
|
||||
@@ -167,9 +164,6 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot_account_id="",
|
||||
)
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
@@ -189,7 +183,9 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content['media_id'])
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
|
||||
"""企业微信目前只有发送给个人的方法,
|
||||
构造target_id的方式为前半部分为账户id,后半部分为agent_id,中间使用“|”符号隔开。
|
||||
"""
|
||||
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
|
||||
parts = target_id.split('|')
|
||||
user_id = parts[0]
|
||||
@@ -221,38 +217,16 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
|
||||
if self.bot_uuid and hasattr(self.logger, 'ap'):
|
||||
try:
|
||||
api_port = self.logger.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
|
||||
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
|
||||
|
||||
await self.logger.info(f"企业微信 Webhook 回调地址:")
|
||||
await self.logger.info(f" 本地地址: {webhook_url}")
|
||||
await self.logger.info(f" 公网地址: {webhook_url_public}")
|
||||
await self.logger.info(f"请在企业微信后台配置此回调地址")
|
||||
except Exception as e:
|
||||
await self.logger.warning(f"无法生成 webhook URL: {e}")
|
||||
|
||||
async def keep_alive():
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
await keep_alive()
|
||||
|
||||
await self.bot.run_task(
|
||||
host=self.config['host'],
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -11,6 +11,23 @@ metadata:
|
||||
icon: wecom.png
|
||||
spec:
|
||||
config:
|
||||
- name: host
|
||||
label:
|
||||
en_US: Host
|
||||
zh_Hans: 监听主机
|
||||
description:
|
||||
en_US: Webhook host, unless you know what you're doing, please write 0.0.0.0
|
||||
zh_Hans: Webhook 监听主机,除非你知道自己在做什么,否则请写 0.0.0.0
|
||||
type: string
|
||||
required: true
|
||||
default: "0.0.0.0"
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2290
|
||||
- name: corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
|
||||
@@ -49,7 +49,7 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
return platform_events.FriendMessage(
|
||||
sender=platform_entities.Friend(
|
||||
id=event.userid,
|
||||
nickname=event.username,
|
||||
nickname='',
|
||||
remark='',
|
||||
),
|
||||
message_chain=message_chain,
|
||||
@@ -61,10 +61,10 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
|
||||
sender = platform_entities.GroupMember(
|
||||
id=event.userid,
|
||||
permission='MEMBER',
|
||||
member_name=event.username,
|
||||
member_name=event.userid,
|
||||
group=platform_entities.Group(
|
||||
id=str(event.chatid),
|
||||
name=event.chatname,
|
||||
name='',
|
||||
permission=platform_entities.Permission.Member,
|
||||
),
|
||||
special_title='',
|
||||
@@ -88,22 +88,19 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
|
||||
event_converter: WecomBotEventConverter = WecomBotEventConverter()
|
||||
config: dict
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: EventLogger):
|
||||
|
||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
|
||||
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId', 'port']
|
||||
missing_keys = [key for key in required_keys if key not in config]
|
||||
if missing_keys:
|
||||
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
|
||||
|
||||
|
||||
# 创建运行时 bot 对象
|
||||
bot = WecomBotClient(
|
||||
Token=config['Token'],
|
||||
EnCodingAESKey=config['EncodingAESKey'],
|
||||
Corpid=config['Corpid'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
bot_account_id = config['BotId']
|
||||
|
||||
@@ -120,50 +117,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
content = await self.message_converter.yiri2target(message)
|
||||
await self.bot.set_message(message_source.source_platform_object.message_id, content)
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
"""将流水线增量输出写入企业微信 stream 会话。
|
||||
|
||||
Args:
|
||||
message_source: 流水线提供的原始消息事件。
|
||||
bot_message: 当前片段对应的模型元信息(未使用)。
|
||||
message: 需要回复的消息链。
|
||||
quote_origin: 是否引用原消息(企业微信暂不支持)。
|
||||
is_final: 标记当前片段是否为最终回复。
|
||||
|
||||
Returns:
|
||||
dict: 包含 `stream` 键,标识写入是否成功。
|
||||
|
||||
Example:
|
||||
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
|
||||
"""
|
||||
# 转换为纯文本(智能机器人当前协议仅支持文本流)
|
||||
content = await self.message_converter.yiri2target(message)
|
||||
msg_id = message_source.source_platform_object.message_id
|
||||
|
||||
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
|
||||
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
|
||||
if not success and is_final:
|
||||
# 未命中流式队列时使用旧有 set_message 兜底
|
||||
await self.bot.set_message(msg_id, content)
|
||||
return {'stream': success}
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""智能机器人侧默认开启流式能力。
|
||||
|
||||
Returns:
|
||||
bool: 恒定返回 True。
|
||||
|
||||
Example:
|
||||
流水线执行阶段会调用此方法以确认是否启用流式。"""
|
||||
return True
|
||||
|
||||
async def send_message(self, target_type, target_id, message):
|
||||
pass
|
||||
|
||||
@@ -185,46 +138,18 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
self.bot.on_message('group')(on_message)
|
||||
except Exception:
|
||||
print(traceback.format_exc())
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
|
||||
async def run_async(self):
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
# 打印 webhook 回调地址
|
||||
if self.bot_uuid and hasattr(self.logger, 'ap'):
|
||||
try:
|
||||
api_port = self.logger.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
|
||||
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
|
||||
|
||||
await self.logger.info(f"企业微信机器人 Webhook 回调地址:")
|
||||
await self.logger.info(f" 本地地址: {webhook_url}")
|
||||
await self.logger.info(f" 公网地址: {webhook_url_public}")
|
||||
await self.logger.info(f"请在企业微信后台配置此回调地址")
|
||||
except Exception as e:
|
||||
await self.logger.warning(f"无法生成 webhook URL: {e}")
|
||||
|
||||
async def keep_alive():
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
await keep_alive()
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -11,6 +11,13 @@ metadata:
|
||||
icon: wecombot.png
|
||||
spec:
|
||||
config:
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: integer
|
||||
required: true
|
||||
default: 2291
|
||||
- name: Corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
|
||||
@@ -121,7 +121,6 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
bot: WecomCSClient = pydantic.Field(exclude=True)
|
||||
message_converter: WecomMessageConverter = WecomMessageConverter()
|
||||
event_converter: WecomEventConverter = WecomEventConverter()
|
||||
bot_uuid: str = None
|
||||
|
||||
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
|
||||
required_keys = [
|
||||
@@ -140,7 +139,6 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
token=config['token'],
|
||||
EncodingAESKey=config['EncodingAESKey'],
|
||||
logger=logger,
|
||||
unified_mode=True,
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
@@ -172,10 +170,6 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
def set_bot_uuid(self, bot_uuid: str):
|
||||
"""设置 bot UUID(用于生成 webhook URL)"""
|
||||
self.bot_uuid = bot_uuid
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
@@ -196,41 +190,16 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
|
||||
elif event_type == platform_events.GroupMessage:
|
||||
pass
|
||||
|
||||
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
|
||||
"""处理统一 webhook 请求。
|
||||
|
||||
Args:
|
||||
bot_uuid: Bot 的 UUID
|
||||
path: 子路径(如果有的话)
|
||||
request: Quart Request 对象
|
||||
|
||||
Returns:
|
||||
响应数据
|
||||
"""
|
||||
return await self.bot.handle_unified_webhook(request)
|
||||
|
||||
async def run_async(self):
|
||||
# 统一 webhook 模式下,不启动独立的 Quart 应用
|
||||
# 保持运行但不启动独立端口
|
||||
|
||||
# 打印 webhook 回调地址
|
||||
if self.bot_uuid and hasattr(self.logger, 'ap'):
|
||||
try:
|
||||
api_port = self.logger.ap.instance_config.data['api']['port']
|
||||
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
|
||||
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
|
||||
|
||||
await self.logger.info(f"企业微信客服 Webhook 回调地址:")
|
||||
await self.logger.info(f" 本地地址: {webhook_url}")
|
||||
await self.logger.info(f" 公网地址: {webhook_url_public}")
|
||||
await self.logger.info(f"请在企业微信后台配置此回调地址")
|
||||
except Exception as e:
|
||||
await self.logger.warning(f"无法生成 webhook URL: {e}")
|
||||
|
||||
async def keep_alive():
|
||||
async def shutdown_trigger_placeholder():
|
||||
while True:
|
||||
await asyncio.sleep(1)
|
||||
await keep_alive()
|
||||
|
||||
await self.bot.run_task(
|
||||
host='0.0.0.0',
|
||||
port=self.config['port'],
|
||||
shutdown_trigger=shutdown_trigger_placeholder,
|
||||
)
|
||||
|
||||
async def kill(self) -> bool:
|
||||
return False
|
||||
|
||||
@@ -11,6 +11,13 @@ metadata:
|
||||
icon: wecom.png
|
||||
spec:
|
||||
config:
|
||||
- name: port
|
||||
label:
|
||||
en_US: Port
|
||||
zh_Hans: 监听端口
|
||||
type: int
|
||||
required: true
|
||||
default: 2289
|
||||
- name: corpid
|
||||
label:
|
||||
en_US: Corpid
|
||||
|
||||
@@ -4,7 +4,6 @@ import json
|
||||
from typing import List
|
||||
from pkg.rag.knowledge.services import base_service
|
||||
from pkg.core import app
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
|
||||
|
||||
class Chunker(base_service.BaseService):
|
||||
@@ -28,6 +27,21 @@ class Chunker(base_service.BaseService):
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
# words = text.split()
|
||||
# chunks = []
|
||||
# current_chunk = []
|
||||
|
||||
# for word in words:
|
||||
# current_chunk.append(word)
|
||||
# if len(current_chunk) > self.chunk_size:
|
||||
# chunks.append(" ".join(current_chunk[:self.chunk_size]))
|
||||
# current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
|
||||
|
||||
# if current_chunk:
|
||||
# chunks.append(" ".join(current_chunk))
|
||||
|
||||
# A more robust chunking strategy (e.g., using recursive character text splitter)
|
||||
from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=self.chunk_size,
|
||||
|
||||
@@ -60,7 +60,6 @@ dependencies = [
|
||||
"ebooklib>=0.18",
|
||||
"html2text>=2024.2.26",
|
||||
"langchain>=0.2.0",
|
||||
"langchain-text-splitters>=0.0.1",
|
||||
"chromadb>=0.4.24",
|
||||
"qdrant-client (>=1.15.1,<2.0.0)",
|
||||
"langbot-plugin==0.1.4",
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useEffect, useState } from 'react';
|
||||
import { useEffect, useState } from 'react';
|
||||
import {
|
||||
IChooseAdapterEntity,
|
||||
IPipelineEntity,
|
||||
@@ -112,87 +112,12 @@ export default function BotForm({
|
||||
IDynamicFormItemSchema[]
|
||||
>([]);
|
||||
const [, setIsLoading] = useState<boolean>(false);
|
||||
const [webhookUrl, setWebhookUrl] = useState<string>('');
|
||||
const webhookInputRef = React.useRef<HTMLInputElement>(null);
|
||||
|
||||
useEffect(() => {
|
||||
setBotFormValues();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, []);
|
||||
|
||||
// 复制到剪贴板的辅助函数 - 使用页面上的真实input元素
|
||||
const copyToClipboard = () => {
|
||||
console.log('[Copy] Attempting to copy from input element');
|
||||
|
||||
const inputElement = webhookInputRef.current;
|
||||
if (!inputElement) {
|
||||
console.error('[Copy] Input element not found');
|
||||
toast.error(t('common.copyFailed'));
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
// 确保input元素可见且未被禁用
|
||||
inputElement.disabled = false;
|
||||
inputElement.readOnly = false;
|
||||
|
||||
// 聚焦并选中所有文本
|
||||
inputElement.focus();
|
||||
inputElement.select();
|
||||
|
||||
// 尝试使用现代API
|
||||
if (navigator.clipboard && navigator.clipboard.writeText) {
|
||||
console.log(
|
||||
'[Copy] Using Clipboard API with input value:',
|
||||
inputElement.value,
|
||||
);
|
||||
navigator.clipboard
|
||||
.writeText(inputElement.value)
|
||||
.then(() => {
|
||||
console.log('[Copy] Clipboard API success');
|
||||
inputElement.blur(); // 取消选中
|
||||
inputElement.readOnly = true;
|
||||
toast.success(t('bots.webhookUrlCopied'));
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error(
|
||||
'[Copy] Clipboard API failed, trying execCommand:',
|
||||
err,
|
||||
);
|
||||
// 降级到execCommand
|
||||
const successful = document.execCommand('copy');
|
||||
console.log('[Copy] execCommand result:', successful);
|
||||
inputElement.blur();
|
||||
inputElement.readOnly = true;
|
||||
if (successful) {
|
||||
toast.success(t('bots.webhookUrlCopied'));
|
||||
} else {
|
||||
toast.error(t('common.copyFailed'));
|
||||
}
|
||||
});
|
||||
} else {
|
||||
// 直接使用execCommand
|
||||
console.log(
|
||||
'[Copy] Using execCommand with input value:',
|
||||
inputElement.value,
|
||||
);
|
||||
const successful = document.execCommand('copy');
|
||||
console.log('[Copy] execCommand result:', successful);
|
||||
inputElement.blur();
|
||||
inputElement.readOnly = true;
|
||||
if (successful) {
|
||||
toast.success(t('bots.webhookUrlCopied'));
|
||||
} else {
|
||||
toast.error(t('common.copyFailed'));
|
||||
}
|
||||
}
|
||||
} catch (err) {
|
||||
console.error('[Copy] Copy failed:', err);
|
||||
inputElement.readOnly = true;
|
||||
toast.error(t('common.copyFailed'));
|
||||
}
|
||||
};
|
||||
|
||||
function setBotFormValues() {
|
||||
initBotFormComponent().then(() => {
|
||||
// 拉取初始化表单信息
|
||||
@@ -208,20 +133,12 @@ export default function BotForm({
|
||||
console.log('form', form.getValues());
|
||||
handleAdapterSelect(val.adapter);
|
||||
// dynamicForm.setFieldsValue(val.adapter_config);
|
||||
|
||||
// 设置 webhook 地址(如果有)
|
||||
if (val.webhook_full_url) {
|
||||
setWebhookUrl(val.webhook_full_url);
|
||||
} else {
|
||||
setWebhookUrl('');
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
toast.error(t('bots.getBotConfigError') + err.message);
|
||||
});
|
||||
} else {
|
||||
form.reset();
|
||||
setWebhookUrl('');
|
||||
}
|
||||
});
|
||||
}
|
||||
@@ -296,7 +213,7 @@ export default function BotForm({
|
||||
|
||||
async function getBotConfig(
|
||||
botId: string,
|
||||
): Promise<z.infer<typeof formSchema> & { webhook_full_url?: string }> {
|
||||
): Promise<z.infer<typeof formSchema>> {
|
||||
return new Promise((resolve, reject) => {
|
||||
httpClient
|
||||
.getBot(botId)
|
||||
@@ -309,10 +226,6 @@ export default function BotForm({
|
||||
adapter_config: bot.adapter_config,
|
||||
enable: bot.enable ?? true,
|
||||
use_pipeline_uuid: bot.use_pipeline_uuid ?? '',
|
||||
webhook_full_url: bot.adapter_runtime_values
|
||||
? ((bot.adapter_runtime_values as Record<string, unknown>)
|
||||
.webhook_full_url as string)
|
||||
: undefined,
|
||||
});
|
||||
})
|
||||
.catch((err) => {
|
||||
@@ -456,86 +369,51 @@ export default function BotForm({
|
||||
<div className="space-y-4">
|
||||
{/* 是否启用 & 绑定流水线 仅在编辑模式 */}
|
||||
{initBotId && (
|
||||
<>
|
||||
<div className="flex items-center gap-6">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="enable"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('common.enable')}</FormLabel>
|
||||
<FormControl>
|
||||
<Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
<div className="flex items-center gap-6">
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="enable"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('common.enable')}</FormLabel>
|
||||
<FormControl>
|
||||
<Switch
|
||||
checked={field.value}
|
||||
onCheckedChange={field.onChange}
|
||||
/>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="use_pipeline_uuid"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('bots.bindPipeline')}</FormLabel>
|
||||
<FormControl>
|
||||
<Select onValueChange={field.onChange} {...field}>
|
||||
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
|
||||
<SelectValue
|
||||
placeholder={t('bots.selectPipeline')}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="fixed z-[1000]">
|
||||
<SelectGroup>
|
||||
{pipelineNameList.map((item) => (
|
||||
<SelectItem
|
||||
key={item.value}
|
||||
value={item.value}
|
||||
>
|
||||
{item.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* Webhook 地址显示(统一 Webhook 模式) */}
|
||||
{webhookUrl && (
|
||||
<FormItem>
|
||||
<FormLabel>{t('bots.webhookUrl')}</FormLabel>
|
||||
<div className="flex items-center gap-2">
|
||||
<Input
|
||||
ref={webhookInputRef}
|
||||
value={webhookUrl}
|
||||
readOnly
|
||||
className="flex-1 bg-gray-50 dark:bg-gray-900"
|
||||
onClick={(e) => {
|
||||
// 点击输入框时自动全选
|
||||
(e.target as HTMLInputElement).select();
|
||||
}}
|
||||
/>
|
||||
<Button
|
||||
type="button"
|
||||
variant="outline"
|
||||
size="sm"
|
||||
onClick={copyToClipboard}
|
||||
>
|
||||
{t('common.copy')}
|
||||
</Button>
|
||||
</div>
|
||||
<p className="text-sm text-gray-500 mt-1">
|
||||
{t('bots.webhookUrlHint')}
|
||||
</p>
|
||||
</FormItem>
|
||||
)}
|
||||
</>
|
||||
<FormField
|
||||
control={form.control}
|
||||
name="use_pipeline_uuid"
|
||||
render={({ field }) => (
|
||||
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
|
||||
<FormLabel>{t('bots.bindPipeline')}</FormLabel>
|
||||
<FormControl>
|
||||
<Select onValueChange={field.onChange} {...field}>
|
||||
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
|
||||
<SelectValue
|
||||
placeholder={t('bots.selectPipeline')}
|
||||
/>
|
||||
</SelectTrigger>
|
||||
<SelectContent className="fixed z-[1000]">
|
||||
<SelectGroup>
|
||||
{pipelineNameList.map((item) => (
|
||||
<SelectItem key={item.value} value={item.value}>
|
||||
{item.label}
|
||||
</SelectItem>
|
||||
))}
|
||||
</SelectGroup>
|
||||
</SelectContent>
|
||||
</Select>
|
||||
</FormControl>
|
||||
</FormItem>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<FormField
|
||||
|
||||
@@ -11,6 +11,10 @@ import { Message } from '@/app/infra/entities/message';
|
||||
import { toast } from 'sonner';
|
||||
import AtBadge from './AtBadge';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
import {
|
||||
PipelineWebSocketClient,
|
||||
SessionType,
|
||||
} from '@/app/infra/websocket/PipelineWebSocketClient';
|
||||
|
||||
interface MessageComponent {
|
||||
type: 'At' | 'Plain';
|
||||
@@ -31,17 +35,27 @@ export default function DebugDialog({
|
||||
}: DebugDialogProps) {
|
||||
const { t } = useTranslation();
|
||||
const [selectedPipelineId, setSelectedPipelineId] = useState(pipelineId);
|
||||
const [sessionType, setSessionType] = useState<'person' | 'group'>('person');
|
||||
const [sessionType, setSessionType] = useState<SessionType>('person');
|
||||
const [messages, setMessages] = useState<Message[]>([]);
|
||||
const [inputValue, setInputValue] = useState('');
|
||||
const [showAtPopover, setShowAtPopover] = useState(false);
|
||||
const [hasAt, setHasAt] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
const [isStreaming, setIsStreaming] = useState(true);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const popoverRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// WebSocket states
|
||||
const [wsClient, setWsClient] = useState<PipelineWebSocketClient | null>(
|
||||
null,
|
||||
);
|
||||
const [connectionStatus, setConnectionStatus] = useState<
|
||||
'disconnected' | 'connecting' | 'connected'
|
||||
>('disconnected');
|
||||
const [pendingMessages, setPendingMessages] = useState<
|
||||
Map<string, Message>
|
||||
>(new Map());
|
||||
|
||||
const scrollToBottom = useCallback(() => {
|
||||
// 使用setTimeout确保在DOM更新后执行滚动
|
||||
setTimeout(() => {
|
||||
@@ -57,37 +71,161 @@ export default function DebugDialog({
|
||||
}, 0);
|
||||
}, []);
|
||||
|
||||
const loadMessages = useCallback(
|
||||
async (pipelineId: string) => {
|
||||
try {
|
||||
const response = await httpClient.getWebChatHistoryMessages(
|
||||
pipelineId,
|
||||
sessionType,
|
||||
);
|
||||
setMessages(response.messages);
|
||||
} catch (error) {
|
||||
console.error('Failed to load messages:', error);
|
||||
}
|
||||
},
|
||||
[sessionType],
|
||||
);
|
||||
// 在useEffect中监听messages变化时滚动
|
||||
// Scroll to bottom when messages change
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages, scrollToBottom]);
|
||||
|
||||
// WebSocket connection setup
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
setSelectedPipelineId(pipelineId);
|
||||
loadMessages(pipelineId);
|
||||
}
|
||||
}, [open, pipelineId]);
|
||||
if (!open) return;
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
loadMessages(selectedPipelineId);
|
||||
}
|
||||
}, [sessionType, selectedPipelineId, open, loadMessages]);
|
||||
const client = new PipelineWebSocketClient(
|
||||
selectedPipelineId,
|
||||
sessionType,
|
||||
);
|
||||
|
||||
// Setup event handlers
|
||||
client.onConnected = (data) => {
|
||||
console.log('[DebugDialog] WebSocket connected:', data);
|
||||
setConnectionStatus('connected');
|
||||
// Load history messages after connection
|
||||
client.loadHistory();
|
||||
};
|
||||
|
||||
client.onHistory = (data) => {
|
||||
console.log('[DebugDialog] History loaded:', data?.messages.length);
|
||||
if (data) {
|
||||
setMessages(data.messages);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageSent = (data) => {
|
||||
console.log('[DebugDialog] Message sent confirmed:', data);
|
||||
if (data) {
|
||||
// Update client message ID to server message ID
|
||||
const clientMsgId = data.client_message_id;
|
||||
const serverMsgId = data.server_message_id;
|
||||
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === -1 && pendingMessages.has(clientMsgId)
|
||||
? { ...msg, id: serverMsgId }
|
||||
: msg,
|
||||
),
|
||||
);
|
||||
|
||||
setPendingMessages((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.delete(clientMsgId);
|
||||
return newMap;
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageStart = (data) => {
|
||||
console.log('[DebugDialog] Message start:', data);
|
||||
if (data) {
|
||||
const placeholderMessage: Message = {
|
||||
id: data.message_id,
|
||||
role: 'assistant',
|
||||
content: '',
|
||||
message_chain: [],
|
||||
timestamp: data.timestamp,
|
||||
};
|
||||
setMessages((prev) => [...prev, placeholderMessage]);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageChunk = (data) => {
|
||||
if (data) {
|
||||
// Update streaming message (content is cumulative)
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === data.message_id
|
||||
? {
|
||||
...msg,
|
||||
content: data.content,
|
||||
message_chain: data.message_chain,
|
||||
}
|
||||
: msg,
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageComplete = (data) => {
|
||||
console.log('[DebugDialog] Message complete:', data);
|
||||
if (data) {
|
||||
// Mark message as complete
|
||||
setMessages((prev) =>
|
||||
prev.map((msg) =>
|
||||
msg.id === data.message_id
|
||||
? {
|
||||
...msg,
|
||||
content: data.final_content,
|
||||
message_chain: data.message_chain,
|
||||
}
|
||||
: msg,
|
||||
),
|
||||
);
|
||||
}
|
||||
};
|
||||
|
||||
client.onMessageError = (data) => {
|
||||
console.error('[DebugDialog] Message error:', data);
|
||||
if (data) {
|
||||
toast.error(`Message error: ${data.error}`);
|
||||
}
|
||||
};
|
||||
|
||||
client.onPluginMessage = (data) => {
|
||||
console.log('[DebugDialog] Plugin message:', data);
|
||||
if (data) {
|
||||
const pluginMessage: Message = {
|
||||
id: data.message_id,
|
||||
role: 'assistant',
|
||||
content: data.content,
|
||||
message_chain: data.message_chain,
|
||||
timestamp: data.timestamp,
|
||||
};
|
||||
setMessages((prev) => [...prev, pluginMessage]);
|
||||
}
|
||||
};
|
||||
|
||||
client.onError = (data) => {
|
||||
console.error('[DebugDialog] WebSocket error:', data);
|
||||
if (data) {
|
||||
toast.error(`WebSocket error: ${data.error}`);
|
||||
}
|
||||
};
|
||||
|
||||
client.onDisconnected = () => {
|
||||
console.log('[DebugDialog] WebSocket disconnected');
|
||||
setConnectionStatus('disconnected');
|
||||
};
|
||||
|
||||
// Connect to WebSocket
|
||||
setConnectionStatus('connecting');
|
||||
client
|
||||
.connect(httpClient.getSessionSync())
|
||||
.then(() => {
|
||||
console.log('[DebugDialog] WebSocket connection established');
|
||||
})
|
||||
.catch((err) => {
|
||||
console.error('[DebugDialog] Failed to connect WebSocket:', err);
|
||||
toast.error('Failed to connect to server');
|
||||
setConnectionStatus('disconnected');
|
||||
});
|
||||
|
||||
setWsClient(client);
|
||||
|
||||
// Cleanup on unmount or session type change
|
||||
return () => {
|
||||
console.log('[DebugDialog] Cleaning up WebSocket connection');
|
||||
client.disconnect();
|
||||
};
|
||||
}, [open, selectedPipelineId, sessionType]); // Reconnect when session type changes
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
@@ -150,6 +288,12 @@ export default function DebugDialog({
|
||||
const sendMessage = async () => {
|
||||
if (!inputValue.trim() && !hasAt) return;
|
||||
|
||||
// Check WebSocket connection
|
||||
if (!wsClient || connectionStatus !== 'connected') {
|
||||
toast.error('Not connected to server');
|
||||
return;
|
||||
}
|
||||
|
||||
try {
|
||||
const messageChain = [];
|
||||
|
||||
@@ -170,7 +314,7 @@ export default function DebugDialog({
|
||||
});
|
||||
|
||||
if (hasAt) {
|
||||
// for showing
|
||||
// For display
|
||||
text_content = '@webchatbot' + text_content;
|
||||
}
|
||||
|
||||
@@ -181,97 +325,26 @@ export default function DebugDialog({
|
||||
timestamp: new Date().toISOString(),
|
||||
message_chain: messageChain,
|
||||
};
|
||||
// 根据isStreaming状态决定使用哪种传输方式
|
||||
if (isStreaming) {
|
||||
// streaming
|
||||
// 创建初始bot消息
|
||||
const placeholderRandomId = Math.floor(Math.random() * 1000000);
|
||||
const botMessagePlaceholder: Message = {
|
||||
id: placeholderRandomId,
|
||||
role: 'assistant',
|
||||
content: 'Generating...',
|
||||
timestamp: new Date().toISOString(),
|
||||
message_chain: [{ type: 'Plain', text: 'Generating...' }],
|
||||
};
|
||||
|
||||
// 添加用户消息和初始bot消息到状态
|
||||
// Add user message to UI immediately
|
||||
setMessages((prevMessages) => [...prevMessages, userMessage]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
|
||||
setMessages((prevMessages) => [
|
||||
...prevMessages,
|
||||
userMessage,
|
||||
botMessagePlaceholder,
|
||||
]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
try {
|
||||
await httpClient.sendStreamingWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
(data) => {
|
||||
// 处理流式响应数据
|
||||
console.log('data', data);
|
||||
if (data.message) {
|
||||
// 更新完整内容
|
||||
// Send via WebSocket
|
||||
const clientMessageId = wsClient.sendMessage(messageChain);
|
||||
|
||||
setMessages((prevMessages) => {
|
||||
const updatedMessages = [...prevMessages];
|
||||
const botMessageIndex = updatedMessages.findIndex(
|
||||
(message) => message.id === placeholderRandomId,
|
||||
);
|
||||
if (botMessageIndex !== -1) {
|
||||
updatedMessages[botMessageIndex] = {
|
||||
...updatedMessages[botMessageIndex],
|
||||
content: data.message.content,
|
||||
message_chain: [
|
||||
{ type: 'Plain', text: data.message.content },
|
||||
],
|
||||
};
|
||||
}
|
||||
return updatedMessages;
|
||||
});
|
||||
}
|
||||
},
|
||||
() => {},
|
||||
(error) => {
|
||||
// 处理错误
|
||||
console.error('Streaming error:', error);
|
||||
if (sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Failed to send streaming message:', error);
|
||||
if (sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// non-streaming
|
||||
setMessages((prevMessages) => [...prevMessages, userMessage]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
// Track pending message for ID mapping
|
||||
setPendingMessages((prev) => {
|
||||
const newMap = new Map(prev);
|
||||
newMap.set(clientMessageId, userMessage);
|
||||
return newMap;
|
||||
});
|
||||
|
||||
const response = await httpClient.sendWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
180000,
|
||||
);
|
||||
|
||||
setMessages((prevMessages) => [...prevMessages, response.message]);
|
||||
}
|
||||
} catch (
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
error: any
|
||||
) {
|
||||
console.log(error, 'type of error', typeof error);
|
||||
console.error('Failed to send message:', error);
|
||||
|
||||
if (!error.message.includes('timeout') && sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
console.log('[DebugDialog] Message sent:', clientMessageId);
|
||||
} catch (error) {
|
||||
console.error('[DebugDialog] Failed to send message:', error);
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
} finally {
|
||||
inputRef.current?.focus();
|
||||
}
|
||||
@@ -390,12 +463,6 @@ export default function DebugDialog({
|
||||
</ScrollArea>
|
||||
|
||||
<div className="p-4 pb-0 bg-white dark:bg-black flex gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-gray-600">
|
||||
{t('pipelines.debugDialog.streaming')}
|
||||
</span>
|
||||
<Switch checked={isStreaming} onCheckedChange={setIsStreaming} />
|
||||
</div>
|
||||
<div className="flex-1 flex items-center gap-2">
|
||||
{hasAt && (
|
||||
<AtBadge targetName="webchatbot" onRemove={handleAtRemove} />
|
||||
|
||||
@@ -141,7 +141,6 @@ export interface Bot {
|
||||
use_pipeline_uuid?: string;
|
||||
created_at?: string;
|
||||
updated_at?: string;
|
||||
adapter_runtime_values?: object;
|
||||
}
|
||||
|
||||
export interface ApiRespKnowledgeBases {
|
||||
|
||||
394
web/src/app/infra/websocket/PipelineWebSocketClient.ts
Normal file
394
web/src/app/infra/websocket/PipelineWebSocketClient.ts
Normal file
@@ -0,0 +1,394 @@
|
||||
/**
|
||||
* Pipeline WebSocket Client
|
||||
*
|
||||
* Provides real-time bidirectional communication for pipeline debugging.
|
||||
* Supports person and group session isolation.
|
||||
*/
|
||||
|
||||
import { Message } from '@/app/infra/entities/message';
|
||||
|
||||
export type SessionType = 'person' | 'group';
|
||||
|
||||
export interface WebSocketEventData {
|
||||
// Connected event
|
||||
connected?: {
|
||||
connection_id: string;
|
||||
session_type: SessionType;
|
||||
pipeline_uuid: string;
|
||||
};
|
||||
|
||||
// History event
|
||||
history?: {
|
||||
messages: Message[];
|
||||
has_more: boolean;
|
||||
};
|
||||
|
||||
// Message sent confirmation
|
||||
message_sent?: {
|
||||
client_message_id: string;
|
||||
server_message_id: number;
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
// Message start
|
||||
message_start?: {
|
||||
message_id: number;
|
||||
role: 'assistant';
|
||||
timestamp: string;
|
||||
reply_to: number;
|
||||
};
|
||||
|
||||
// Message chunk
|
||||
message_chunk?: {
|
||||
message_id: number;
|
||||
content: string;
|
||||
message_chain: object[];
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
// Message complete
|
||||
message_complete?: {
|
||||
message_id: number;
|
||||
final_content: string;
|
||||
message_chain: object[];
|
||||
timestamp: string;
|
||||
};
|
||||
|
||||
// Message error
|
||||
message_error?: {
|
||||
message_id: number;
|
||||
error: string;
|
||||
error_code?: string;
|
||||
};
|
||||
|
||||
// Interrupted
|
||||
interrupted?: {
|
||||
message_id: number;
|
||||
partial_content: string;
|
||||
};
|
||||
|
||||
// Plugin message
|
||||
plugin_message?: {
|
||||
message_id: number;
|
||||
role: 'assistant';
|
||||
content: string;
|
||||
message_chain: object[];
|
||||
timestamp: string;
|
||||
source: 'plugin';
|
||||
};
|
||||
|
||||
// Error
|
||||
error?: {
|
||||
error: string;
|
||||
error_code: string;
|
||||
details?: object;
|
||||
};
|
||||
|
||||
// Pong
|
||||
pong?: {
|
||||
timestamp: number;
|
||||
};
|
||||
}
|
||||
|
||||
export class PipelineWebSocketClient {
|
||||
private ws: WebSocket | null = null;
|
||||
private pipelineId: string;
|
||||
private sessionType: SessionType;
|
||||
private reconnectAttempts = 0;
|
||||
private maxReconnectAttempts = 5;
|
||||
private pingInterval: NodeJS.Timeout | null = null;
|
||||
private reconnectTimeout: NodeJS.Timeout | null = null;
|
||||
private isManualDisconnect = false;
|
||||
|
||||
// Event callbacks
|
||||
public onConnected?: (data: WebSocketEventData['connected']) => void;
|
||||
public onHistory?: (data: WebSocketEventData['history']) => void;
|
||||
public onMessageSent?: (data: WebSocketEventData['message_sent']) => void;
|
||||
public onMessageStart?: (data: WebSocketEventData['message_start']) => void;
|
||||
public onMessageChunk?: (data: WebSocketEventData['message_chunk']) => void;
|
||||
public onMessageComplete?: (
|
||||
data: WebSocketEventData['message_complete'],
|
||||
) => void;
|
||||
public onMessageError?: (data: WebSocketEventData['message_error']) => void;
|
||||
public onInterrupted?: (data: WebSocketEventData['interrupted']) => void;
|
||||
public onPluginMessage?: (data: WebSocketEventData['plugin_message']) => void;
|
||||
public onError?: (data: WebSocketEventData['error']) => void;
|
||||
public onDisconnected?: () => void;
|
||||
|
||||
constructor(pipelineId: string, sessionType: SessionType) {
|
||||
this.pipelineId = pipelineId;
|
||||
this.sessionType = sessionType;
|
||||
}
|
||||
|
||||
/**
|
||||
* Connect to WebSocket server
|
||||
*/
|
||||
connect(token: string): Promise<void> {
|
||||
return new Promise((resolve, reject) => {
|
||||
this.isManualDisconnect = false;
|
||||
|
||||
const wsUrl = this.buildWebSocketUrl();
|
||||
console.log(`[WebSocket] Connecting to ${wsUrl}...`);
|
||||
|
||||
try {
|
||||
this.ws = new WebSocket(wsUrl);
|
||||
} catch (error) {
|
||||
console.error('[WebSocket] Failed to create WebSocket:', error);
|
||||
reject(error);
|
||||
return;
|
||||
}
|
||||
|
||||
this.ws.onopen = () => {
|
||||
console.log('[WebSocket] Connection opened');
|
||||
|
||||
// Send connect event with session type and token
|
||||
this.send('connect', {
|
||||
pipeline_uuid: this.pipelineId,
|
||||
session_type: this.sessionType,
|
||||
token,
|
||||
});
|
||||
|
||||
// Start ping interval
|
||||
this.startPing();
|
||||
|
||||
// Reset reconnect attempts on successful connection
|
||||
this.reconnectAttempts = 0;
|
||||
|
||||
resolve();
|
||||
};
|
||||
|
||||
this.ws.onmessage = (event) => {
|
||||
this.handleMessage(event);
|
||||
};
|
||||
|
||||
this.ws.onerror = (error) => {
|
||||
console.error('[WebSocket] Error:', error);
|
||||
reject(error);
|
||||
};
|
||||
|
||||
this.ws.onclose = (event) => {
|
||||
console.log(
|
||||
`[WebSocket] Connection closed: code=${event.code}, reason=${event.reason}`,
|
||||
);
|
||||
this.handleDisconnect();
|
||||
};
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle incoming WebSocket message
|
||||
*/
|
||||
private handleMessage(event: MessageEvent) {
|
||||
try {
|
||||
const message = JSON.parse(event.data);
|
||||
const { type, data } = message;
|
||||
|
||||
console.log(`[WebSocket] Received: ${type}`, data);
|
||||
|
||||
switch (type) {
|
||||
case 'connected':
|
||||
this.onConnected?.(data);
|
||||
break;
|
||||
case 'history':
|
||||
this.onHistory?.(data);
|
||||
break;
|
||||
case 'message_sent':
|
||||
this.onMessageSent?.(data);
|
||||
break;
|
||||
case 'message_start':
|
||||
this.onMessageStart?.(data);
|
||||
break;
|
||||
case 'message_chunk':
|
||||
this.onMessageChunk?.(data);
|
||||
break;
|
||||
case 'message_complete':
|
||||
this.onMessageComplete?.(data);
|
||||
break;
|
||||
case 'message_error':
|
||||
this.onMessageError?.(data);
|
||||
break;
|
||||
case 'interrupted':
|
||||
this.onInterrupted?.(data);
|
||||
break;
|
||||
case 'plugin_message':
|
||||
this.onPluginMessage?.(data);
|
||||
break;
|
||||
case 'error':
|
||||
this.onError?.(data);
|
||||
break;
|
||||
case 'pong':
|
||||
// Heartbeat response, no action needed
|
||||
break;
|
||||
default:
|
||||
console.warn(`[WebSocket] Unknown message type: ${type}`);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('[WebSocket] Failed to parse message:', error);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Send message to server
|
||||
*/
|
||||
sendMessage(messageChain: object[]): string {
|
||||
const clientMessageId = this.generateMessageId();
|
||||
this.send('send_message', {
|
||||
message_chain: messageChain,
|
||||
client_message_id: clientMessageId,
|
||||
});
|
||||
return clientMessageId;
|
||||
}
|
||||
|
||||
/**
|
||||
* Load history messages
|
||||
*/
|
||||
loadHistory(beforeMessageId?: number, limit?: number) {
|
||||
this.send('load_history', {
|
||||
before_message_id: beforeMessageId,
|
||||
limit,
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Interrupt streaming message
|
||||
*/
|
||||
interrupt(messageId: number) {
|
||||
this.send('interrupt', { message_id: messageId });
|
||||
}
|
||||
|
||||
/**
|
||||
* Send event to server
|
||||
*/
|
||||
private send(type: string, data: object) {
|
||||
if (this.ws?.readyState === WebSocket.OPEN) {
|
||||
const message = JSON.stringify({ type, data });
|
||||
console.log(`[WebSocket] Sending: ${type}`, data);
|
||||
this.ws.send(message);
|
||||
} else {
|
||||
console.warn(
|
||||
`[WebSocket] Cannot send message, connection not open (state: ${this.ws?.readyState})`,
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Start ping interval (heartbeat)
|
||||
*/
|
||||
private startPing() {
|
||||
this.stopPing();
|
||||
this.pingInterval = setInterval(() => {
|
||||
this.send('ping', { timestamp: Date.now() });
|
||||
}, 30000); // Ping every 30 seconds
|
||||
}
|
||||
|
||||
/**
|
||||
* Stop ping interval
|
||||
*/
|
||||
private stopPing() {
|
||||
if (this.pingInterval) {
|
||||
clearInterval(this.pingInterval);
|
||||
this.pingInterval = null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Handle disconnection
|
||||
*/
|
||||
private handleDisconnect() {
|
||||
this.stopPing();
|
||||
this.onDisconnected?.();
|
||||
|
||||
// Auto reconnect if not manual disconnect
|
||||
if (
|
||||
!this.isManualDisconnect &&
|
||||
this.reconnectAttempts < this.maxReconnectAttempts
|
||||
) {
|
||||
const delay = Math.min(2000 * Math.pow(2, this.reconnectAttempts), 30000);
|
||||
console.log(
|
||||
`[WebSocket] Reconnecting in ${delay}ms (attempt ${this.reconnectAttempts + 1}/${this.maxReconnectAttempts})`,
|
||||
);
|
||||
|
||||
this.reconnectTimeout = setTimeout(() => {
|
||||
this.reconnectAttempts++;
|
||||
// Note: Need to get token again, should be handled by caller
|
||||
console.warn(
|
||||
'[WebSocket] Auto-reconnect requires token, please reconnect manually',
|
||||
);
|
||||
}, delay);
|
||||
} else if (this.reconnectAttempts >= this.maxReconnectAttempts) {
|
||||
console.error(
|
||||
'[WebSocket] Max reconnect attempts reached, giving up',
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Disconnect from server
|
||||
*/
|
||||
disconnect() {
|
||||
this.isManualDisconnect = true;
|
||||
this.stopPing();
|
||||
|
||||
if (this.reconnectTimeout) {
|
||||
clearTimeout(this.reconnectTimeout);
|
||||
this.reconnectTimeout = null;
|
||||
}
|
||||
|
||||
if (this.ws) {
|
||||
this.ws.close(1000, 'Client disconnect');
|
||||
this.ws = null;
|
||||
}
|
||||
|
||||
console.log('[WebSocket] Disconnected');
|
||||
}
|
||||
|
||||
/**
|
||||
* Build WebSocket URL
|
||||
*/
|
||||
private buildWebSocketUrl(): string {
|
||||
// Get current base URL
|
||||
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
|
||||
const host = window.location.host;
|
||||
|
||||
return `${protocol}//${host}/api/v1/pipelines/${this.pipelineId}/chat/ws`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Generate unique client message ID
|
||||
*/
|
||||
private generateMessageId(): string {
|
||||
return `${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get connection state
|
||||
*/
|
||||
getState():
|
||||
| 'CONNECTING'
|
||||
| 'OPEN'
|
||||
| 'CLOSING'
|
||||
| 'CLOSED'
|
||||
| 'DISCONNECTED' {
|
||||
if (!this.ws) return 'DISCONNECTED';
|
||||
|
||||
switch (this.ws.readyState) {
|
||||
case WebSocket.CONNECTING:
|
||||
return 'CONNECTING';
|
||||
case WebSocket.OPEN:
|
||||
return 'OPEN';
|
||||
case WebSocket.CLOSING:
|
||||
return 'CLOSING';
|
||||
case WebSocket.CLOSED:
|
||||
return 'CLOSED';
|
||||
default:
|
||||
return 'DISCONNECTED';
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if connected
|
||||
*/
|
||||
isConnected(): boolean {
|
||||
return this.ws?.readyState === WebSocket.OPEN;
|
||||
}
|
||||
}
|
||||
@@ -39,9 +39,7 @@ const enUS = {
|
||||
deleteSuccess: 'Deleted successfully',
|
||||
deleteError: 'Delete failed: ',
|
||||
addRound: 'Add Round',
|
||||
copy: 'Copy',
|
||||
copySuccess: 'Copy Successfully',
|
||||
copyFailed: 'Copy Failed',
|
||||
test: 'Test',
|
||||
forgotPassword: 'Forgot Password?',
|
||||
loading: 'Loading...',
|
||||
@@ -151,10 +149,6 @@ const enUS = {
|
||||
log: 'Log',
|
||||
configuration: 'Configuration',
|
||||
logs: 'Logs',
|
||||
webhookUrl: 'Webhook Callback URL',
|
||||
webhookUrlCopied: 'Webhook URL copied',
|
||||
webhookUrlHint:
|
||||
'Click the input to select all, then press Ctrl+C (Mac: Cmd+C) to copy, or click the button',
|
||||
},
|
||||
plugins: {
|
||||
title: 'Plugins',
|
||||
|
||||
@@ -40,9 +40,7 @@ const jaJP = {
|
||||
deleteSuccess: '削除に成功しました',
|
||||
deleteError: '削除に失敗しました:',
|
||||
addRound: 'ラウンドを追加',
|
||||
copy: 'コピー',
|
||||
copySuccess: 'コピーに成功しました',
|
||||
copyFailed: 'コピーに失敗しました',
|
||||
test: 'テスト',
|
||||
forgotPassword: 'パスワードを忘れた?',
|
||||
loading: '読み込み中...',
|
||||
@@ -153,10 +151,6 @@ const jaJP = {
|
||||
log: 'ログ',
|
||||
configuration: '設定',
|
||||
logs: 'ログ',
|
||||
webhookUrl: 'Webhook コールバック URL',
|
||||
webhookUrlCopied: 'Webhook URL をコピーしました',
|
||||
webhookUrlHint:
|
||||
'入力ボックスをクリックして全選択し、Ctrl+C (Mac: Cmd+C) でコピーするか、右側のボタンをクリックしてください',
|
||||
},
|
||||
plugins: {
|
||||
title: 'プラグイン',
|
||||
|
||||
@@ -39,9 +39,7 @@ const zhHans = {
|
||||
deleteSuccess: '删除成功',
|
||||
deleteError: '删除失败:',
|
||||
addRound: '添加回合',
|
||||
copy: '复制',
|
||||
copySuccess: '复制成功',
|
||||
copyFailed: '复制失败',
|
||||
test: '测试',
|
||||
forgotPassword: '忘记密码?',
|
||||
loading: '加载中...',
|
||||
@@ -148,10 +146,6 @@ const zhHans = {
|
||||
log: '日志',
|
||||
configuration: '配置',
|
||||
logs: '日志',
|
||||
webhookUrl: 'Webhook 回调地址',
|
||||
webhookUrlCopied: 'Webhook 地址已复制',
|
||||
webhookUrlHint:
|
||||
'点击输入框自动全选,然后按 Ctrl+C (Mac: Cmd+C) 复制,或点击右侧按钮',
|
||||
},
|
||||
plugins: {
|
||||
title: '插件管理',
|
||||
|
||||
@@ -39,9 +39,7 @@ const zhHant = {
|
||||
deleteSuccess: '刪除成功',
|
||||
deleteError: '刪除失敗:',
|
||||
addRound: '新增回合',
|
||||
copy: '複製',
|
||||
copySuccess: '複製成功',
|
||||
copyFailed: '複製失敗',
|
||||
test: '測試',
|
||||
forgotPassword: '忘記密碼?',
|
||||
loading: '載入中...',
|
||||
@@ -148,10 +146,6 @@ const zhHant = {
|
||||
log: '日誌',
|
||||
configuration: '設定',
|
||||
logs: '日誌',
|
||||
webhookUrl: 'Webhook 回調位址',
|
||||
webhookUrlCopied: 'Webhook 位址已複製',
|
||||
webhookUrlHint:
|
||||
'點擊輸入框自動全選,然後按 Ctrl+C (Mac: Cmd+C) 複製,或點擊右側按鈕',
|
||||
},
|
||||
plugins: {
|
||||
title: '外掛管理',
|
||||
|
||||
Reference in New Issue
Block a user