Compare commits

..

1 Commits

Author SHA1 Message Date
RockChinQ
a4995c8cd9 feat: Migrate pipeline debug from SSE to WebSocket
Replace Server-Sent Events (SSE) with WebSocket for pipeline debugging
to support bidirectional real-time communication and resolve timeout
limitations.

## Backend Changes

- Add WebSocketConnectionPool for managing client connections
- Implement WebSocket route handler at /api/v1/pipelines/<uuid>/chat/ws
- Modify WebChatAdapter to broadcast messages via WebSocket
- Support both legacy SSE and new WebSocket simultaneously
- Maintain person/group session isolation

## Frontend Changes

- Create PipelineWebSocketClient for WebSocket communication
- Update DebugDialog to use WebSocket instead of SSE
- Support auto-reconnection and connection status tracking
- Remove streaming toggle (WebSocket always supports streaming)

## Key Features

- Eliminates 120-second timeout limitation
- Supports multiple messages per request
- Enables backend push notifications (plugins)
- Maintains session isolation (person vs group)
- Backward compatible with SSE during transition

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-10-22 12:50:18 +00:00
38 changed files with 1524 additions and 1465 deletions

View File

@@ -23,25 +23,20 @@ xml_template = """
class OAClient:
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None, unified_mode: bool = False):
def __init__(self, token: str, EncodingAESKey: str, AppID: str, Appsecret: str, logger: None):
self.token = token
self.aes = EncodingAESKey
self.appid = AppID
self.appsecret = Appsecret
self.base_url = 'https://api.weixin.qq.com'
self.access_token = ''
self.unified_mode = unified_mode
self.app = Quart(__name__)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
'example': [],
}
@@ -51,39 +46,19 @@ class OAClient:
self.logger = logger
async def handle_callback_request(self):
"""处理回调请求(独立端口模式,使用全局 request"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
Args:
req: Quart Request 对象
"""
try:
# 每隔100毫秒查询是否生成ai回答
start_time = time.time()
signature = req.args.get('signature', '')
timestamp = req.args.get('timestamp', '')
nonce = req.args.get('nonce', '')
echostr = req.args.get('echostr', '')
msg_signature = req.args.get('msg_signature', '')
signature = request.args.get('signature', '')
timestamp = request.args.get('timestamp', '')
nonce = request.args.get('nonce', '')
echostr = request.args.get('echostr', '')
msg_signature = request.args.get('msg_signature', '')
if msg_signature is None:
await self.logger.error('msg_signature不在请求体中')
raise Exception('msg_signature不在请求体中')
if req.method == 'GET':
if request.method == 'GET':
# 校验签名
check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
@@ -93,8 +68,8 @@ class OAClient:
else:
await self.logger.error('拒绝请求')
raise Exception('拒绝请求')
elif req.method == 'POST':
encryt_msg = await req.data
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
xml_msg = xml_msg.decode('utf-8')
@@ -207,7 +182,6 @@ class OAClientForLongerResponse:
Appsecret: str,
LoadingMessage: str,
logger: None,
unified_mode: bool = False,
):
self.token = token
self.aes = EncodingAESKey
@@ -215,18 +189,13 @@ class OAClientForLongerResponse:
self.appsecret = Appsecret
self.base_url = 'https://api.weixin.qq.com'
self.access_token = ''
self.unified_mode = unified_mode
self.app = Quart(__name__)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
'example': [],
}
@@ -237,44 +206,24 @@ class OAClientForLongerResponse:
self.logger = logger
async def handle_callback_request(self):
"""处理回调请求(独立端口模式,使用全局 request"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
Args:
req: Quart Request 对象
"""
try:
signature = req.args.get('signature', '')
timestamp = req.args.get('timestamp', '')
nonce = req.args.get('nonce', '')
echostr = req.args.get('echostr', '')
msg_signature = req.args.get('msg_signature', '')
signature = request.args.get('signature', '')
timestamp = request.args.get('timestamp', '')
nonce = request.args.get('nonce', '')
echostr = request.args.get('echostr', '')
msg_signature = request.args.get('msg_signature', '')
if msg_signature is None:
await self.logger.error('msg_signature不在请求体中')
raise Exception('msg_signature不在请求体中')
if req.method == 'GET':
if request.method == 'GET':
check_str = ''.join(sorted([self.token, timestamp, nonce]))
check_signature = hashlib.sha1(check_str.encode('utf-8')).hexdigest()
return echostr if check_signature == signature else '拒绝请求'
elif req.method == 'POST':
encryt_msg = await req.data
elif request.method == 'POST':
encryt_msg = await request.data
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.appid)
ret, xml_msg = wxcpt.DecryptMsg(encryt_msg, msg_signature, timestamp, nonce)
xml_msg = xml_msg.decode('utf-8')

View File

@@ -10,20 +10,38 @@ import traceback
from cryptography.hazmat.primitives.asymmetric import ed25519
def handle_validation(body: dict, bot_secret: str):
# bot正确的secert是32位的此处仅为了适配演示demo
while len(bot_secret) < 32:
bot_secret = bot_secret * 2
bot_secret = bot_secret[:32]
# 实际使用场景中以上三行内容可清除
seed_bytes = bot_secret.encode()
signing_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed_bytes)
msg = body['d']['event_ts'] + body['d']['plain_token']
msg_bytes = msg.encode()
signature = signing_key.sign(msg_bytes)
signature_hex = signature.hex()
response = {'plain_token': body['d']['plain_token'], 'signature': signature_hex}
return response
class QQOfficialClient:
def __init__(self, secret: str, token: str, app_id: str, logger: None, unified_mode: bool = False):
self.unified_mode = unified_mode
def __init__(self, secret: str, token: str, app_id: str, logger: None):
self.app = Quart(__name__)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self.secret = secret
self.token = token
self.app_id = app_id
@@ -64,45 +82,18 @@ class QQOfficialClient:
raise Exception(f'获取access_token失败: {e}')
async def handle_callback_request(self):
"""处理回调请求(独立端口模式,使用全局 request"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""处理回调请求的内部实现。
Args:
req: Quart Request 对象
"""
"""处理回调请求"""
try:
body = await req.get_data()
print(f'[QQ Official] Received request, body length: {len(body)}')
if not body or len(body) == 0:
print('[QQ Official] Received empty body, might be health check or GET request')
return {'code': 0, 'message': 'ok'}, 200
# 读取请求数据
body = await request.get_data()
payload = json.loads(body)
# 验证是否为回调验证请求
if payload.get('op') == 13:
validation_data = payload.get('d')
if not validation_data:
return {'error': "missing 'd' field"}, 400
response = await self.verify(validation_data)
return response, 200
# 生成签名
response = handle_validation(payload, self.secret)
return response
if payload.get('op') == 0:
message_data = await self.get_message(payload)
@@ -113,7 +104,6 @@ class QQOfficialClient:
return {'code': 0, 'message': 'success'}
except Exception as e:
print(f'[QQ Official] ERROR: {traceback.format_exc()}')
await self.logger.error(f'Error in handle_callback_request: {traceback.format_exc()}')
return {'error': str(e)}, 400
@@ -271,26 +261,3 @@ class QQOfficialClient:
if self.access_token_expiry_time is None:
return True
return time.time() > self.access_token_expiry_time
async def repeat_seed(self, bot_secret: str, target_size: int = 32) -> bytes:
seed = bot_secret
while len(seed) < target_size:
seed *= 2
return seed[:target_size].encode("utf-8")
async def verify(self, validation_payload: dict):
seed = await self.repeat_seed(self.secret)
private_key = ed25519.Ed25519PrivateKey.from_private_bytes(seed)
event_ts = validation_payload.get("event_ts", "")
plain_token = validation_payload.get("plain_token", "")
msg = event_ts + plain_token
# sign
signature = private_key.sign(msg.encode()).hex()
response = {
"plain_token": plain_token,
"signature": signature,
}
return response

View File

@@ -8,19 +8,14 @@ import langbot_plugin.api.entities.builtin.platform.events as platform_events
class SlackClient:
def __init__(self, bot_token: str, signing_secret: str, logger: None, unified_mode: bool = False):
def __init__(self, bot_token: str, signing_secret: str, logger: None):
self.bot_token = bot_token
self.signing_secret = signing_secret
self.unified_mode = unified_mode
self.app = Quart(__name__)
self.client = AsyncWebClient(self.bot_token)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
)
self.app.add_url_rule(
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
)
self._message_handlers = {
'example': [],
}
@@ -28,28 +23,8 @@ class SlackClient:
self.logger = logger
async def handle_callback_request(self):
"""处理回调请求(独立端口模式,使用全局 request"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""处理回调请求的内部实现。
Args:
req: Quart Request 对象
"""
try:
body = await req.get_data()
body = await request.get_data()
data = json.loads(body)
if 'type' in data:
if data['type'] == 'url_verification':

View File

@@ -1,477 +1,189 @@
import asyncio
import base64
import json
import time
import traceback
import uuid
import xml.etree.ElementTree as ET
from dataclasses import dataclass, field
from typing import Any, Callable, Optional
from urllib.parse import unquote
import hashlib
import traceback
import httpx
from Crypto.Cipher import AES
from quart import Quart, request, Response, jsonify
from libs.wecom_ai_bot_api import wecombotevent
from libs.wecom_ai_bot_api.WXBizMsgCrypt3 import WXBizMsgCrypt
from quart import Quart, request, Response, jsonify
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import asyncio
from libs.wecom_ai_bot_api import wecombotevent
from typing import Callable
import base64
from Crypto.Cipher import AES
from pkg.platform.logger import EventLogger
@dataclass
class StreamChunk:
"""描述单次推送给企业微信的流式片段。"""
# 需要返回给企业微信的文本内容
content: str
# 标记是否为最终片段,对应企业微信协议里的 finish 字段
is_final: bool = False
# 预留额外元信息,未来支持多模态扩展时可使用
meta: dict[str, Any] = field(default_factory=dict)
@dataclass
class StreamSession:
"""维护一次企业微信流式会话的上下文。"""
# 企业微信要求的 stream_id用于标识后续刷新请求
stream_id: str
# 原始消息的 msgid便于与流水线消息对应
msg_id: str
# 群聊会话标识(单聊时为空)
chat_id: Optional[str]
# 触发消息的发送者
user_id: Optional[str]
# 会话创建时间
created_at: float = field(default_factory=time.time)
# 最近一次被访问的时间cleanup 依据该值判断过期
last_access: float = field(default_factory=time.time)
# 将流水线增量结果缓存到队列,刷新请求逐条消费
queue: asyncio.Queue = field(default_factory=asyncio.Queue)
# 是否已经完成(收到最终片段)
finished: bool = False
# 缓存最近一次片段,处理重试或超时兜底
last_chunk: Optional[StreamChunk] = None
class StreamSessionManager:
"""管理 stream 会话的生命周期,并负责队列的生产消费。"""
def __init__(self, logger: EventLogger, ttl: int = 60) -> None:
self.logger = logger
self.ttl = ttl # 超时时间(秒),超过该时间未被访问的会话会被清理由 cleanup
self._sessions: dict[str, StreamSession] = {} # stream_id -> StreamSession 映射
self._msg_index: dict[str, str] = {} # msgid -> stream_id 映射,便于流水线根据消息 ID 找到会话
def get_stream_id_by_msg(self, msg_id: str) -> Optional[str]:
if not msg_id:
return None
return self._msg_index.get(msg_id)
def get_session(self, stream_id: str) -> Optional[StreamSession]:
return self._sessions.get(stream_id)
def create_or_get(self, msg_json: dict[str, Any]) -> tuple[StreamSession, bool]:
"""根据企业微信回调创建或获取会话。
Args:
msg_json: 企业微信解密后的回调 JSON。
Returns:
Tuple[StreamSession, bool]: `StreamSession` 为会话实例,`bool` 指示是否为新建会话。
Example:
在首次回调中调用,得到 `is_new=True` 后再触发流水线。
"""
msg_id = msg_json.get('msgid', '')
if msg_id and msg_id in self._msg_index:
stream_id = self._msg_index[msg_id]
session = self._sessions.get(stream_id)
if session:
session.last_access = time.time()
return session, False
stream_id = str(uuid.uuid4())
session = StreamSession(
stream_id=stream_id,
msg_id=msg_id,
chat_id=msg_json.get('chatid'),
user_id=msg_json.get('from', {}).get('userid'),
)
if msg_id:
self._msg_index[msg_id] = stream_id
self._sessions[stream_id] = session
return session, True
async def publish(self, stream_id: str, chunk: StreamChunk) -> bool:
"""向 stream 队列写入新的增量片段。
Args:
stream_id: 企业微信分配的流式会话 ID。
chunk: 待发送的增量片段。
Returns:
bool: 当流式队列存在并成功入队时返回 True。
Example:
在收到模型增量后调用 `await manager.publish('sid', StreamChunk('hello'))`。
"""
session = self._sessions.get(stream_id)
if not session:
return False
session.last_access = time.time()
session.last_chunk = chunk
try:
session.queue.put_nowait(chunk)
except asyncio.QueueFull:
# 默认无界队列,此处兜底防御
await session.queue.put(chunk)
if chunk.is_final:
session.finished = True
return True
async def consume(self, stream_id: str, timeout: float = 0.5) -> Optional[StreamChunk]:
"""从队列中取出一个片段,若超时返回 None。
Args:
stream_id: 企业微信流式会话 ID。
timeout: 取片段的最长等待时间(秒)。
Returns:
Optional[StreamChunk]: 成功时返回片段,超时或会话不存在时返回 None。
Example:
企业微信刷新到达时调用,若队列有数据则立即返回 `StreamChunk`。
"""
session = self._sessions.get(stream_id)
if not session:
return None
session.last_access = time.time()
try:
chunk = await asyncio.wait_for(session.queue.get(), timeout)
session.last_access = time.time()
if chunk.is_final:
session.finished = True
return chunk
except asyncio.TimeoutError:
if session.finished and session.last_chunk:
return session.last_chunk
return None
def mark_finished(self, stream_id: str) -> None:
session = self._sessions.get(stream_id)
if session:
session.finished = True
session.last_access = time.time()
def cleanup(self) -> None:
"""定期清理过期会话,防止队列与映射无上限累积。"""
now = time.time()
expired: list[str] = []
for stream_id, session in self._sessions.items():
if now - session.last_access > self.ttl:
expired.append(stream_id)
for stream_id in expired:
session = self._sessions.pop(stream_id, None)
if not session:
continue
msg_id = session.msg_id
if msg_id and self._msg_index.get(msg_id) == stream_id:
self._msg_index.pop(msg_id, None)
class WecomBotClient:
def __init__(self, Token: str, EnCodingAESKey: str, Corpid: str, logger: EventLogger, unified_mode: bool = False):
"""企业微信智能机器人客户端。
Args:
Token: 企业微信回调验证使用的 token。
EnCodingAESKey: 企业微信消息加解密密钥。
Corpid: 企业 ID。
logger: 日志记录器。
unified_mode: 是否使用统一 webhook 模式(默认 False
Example:
>>> client = WecomBotClient(Token='token', EnCodingAESKey='aeskey', Corpid='corp', logger=logger)
"""
self.Token = Token
self.EnCodingAESKey = EnCodingAESKey
self.Corpid = Corpid
def __init__(self,Token:str,EnCodingAESKey:str,Corpid:str,logger:EventLogger):
self.Token=Token
self.EnCodingAESKey=EnCodingAESKey
self.Corpid=Corpid
self.ReceiveId = ''
self.unified_mode = unified_mode
self.app = Quart(__name__)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['POST', 'GET']
)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['POST','GET']
)
self._message_handlers = {
'example': [],
}
self.user_stream_map = {}
self.logger = logger
self.generated_content: dict[str, str] = {}
self.msg_id_map: dict[str, int] = {}
self.stream_sessions = StreamSessionManager(logger=logger)
self.stream_poll_timeout = 0.5
@staticmethod
def _build_stream_payload(stream_id: str, content: str, finish: bool) -> dict[str, Any]:
"""按照企业微信协议拼装返回报文。
Args:
stream_id: 企业微信会话 ID。
content: 推送的文本内容。
finish: 是否为最终片段。
Returns:
dict[str, Any]: 可直接加密返回的 payload。
Example:
组装 `{'msgtype': 'stream', 'stream': {'id': 'sid', ...}}` 结构。
"""
return {
'msgtype': 'stream',
'stream': {
'id': stream_id,
'finish': finish,
'content': content,
},
}
async def _encrypt_and_reply(self, payload: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""对响应进行加密封装并返回给企业微信。
Args:
payload: 待加密的响应内容。
nonce: 企业微信回调参数中的 nonce。
Returns:
Tuple[Response, int]: Quart Response 对象及状态码。
Example:
在首包或刷新场景中调用以生成加密响应。
"""
reply_plain_str = json.dumps(payload, ensure_ascii=False)
reply_timestamp = str(int(time.time()))
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
if ret != 0:
await self.logger.error(f'加密失败: {ret}')
return jsonify({'error': 'encrypt_failed'}), 500
root = ET.fromstring(encrypt_text)
encrypt = root.find('Encrypt').text
resp = {
'encrypt': encrypt,
}
return jsonify(resp), 200
async def _dispatch_event(self, event: wecombotevent.WecomBotEvent) -> None:
"""异步触发流水线处理,避免阻塞首包响应。
Args:
event: 由企业微信消息转换的内部事件对象。
"""
try:
await self._handle_message(event)
except Exception:
await self.logger.error(traceback.format_exc())
async def _handle_post_initial_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信首次推送的消息,返回 stream_id 并开启流水线。
Args:
msg_json: 解密后的企业微信消息 JSON。
nonce: 企业微信回调参数 nonce。
Returns:
Tuple[Response, int]: Quart Response 及状态码。
Example:
首次回调时调用,立即返回带 `stream_id` 的响应。
"""
session, is_new = self.stream_sessions.create_or_get(msg_json)
message_data = await self.get_message(msg_json)
if message_data:
message_data['stream_id'] = session.stream_id
try:
event = wecombotevent.WecomBotEvent(message_data)
except Exception:
await self.logger.error(traceback.format_exc())
else:
if is_new:
asyncio.create_task(self._dispatch_event(event))
payload = self._build_stream_payload(session.stream_id, '', False)
return await self._encrypt_and_reply(payload, nonce)
async def _handle_post_followup_response(self, msg_json: dict[str, Any], nonce: str) -> tuple[Response, int]:
"""处理企业微信的流式刷新请求,按需返回增量片段。
Args:
msg_json: 解密后的企业微信刷新请求。
nonce: 企业微信回调参数 nonce。
Returns:
Tuple[Response, int]: Quart Response 及状态码。
Example:
在刷新请求中调用,按需返回增量片段。
"""
stream_info = msg_json.get('stream', {})
stream_id = stream_info.get('id', '')
if not stream_id:
await self.logger.error('刷新请求缺少 stream.id')
return await self._encrypt_and_reply(self._build_stream_payload('', '', True), nonce)
session = self.stream_sessions.get_session(stream_id)
chunk = await self.stream_sessions.consume(stream_id, timeout=self.stream_poll_timeout)
if not chunk:
cached_content = None
if session and session.msg_id:
cached_content = self.generated_content.pop(session.msg_id, None)
if cached_content is not None:
chunk = StreamChunk(content=cached_content, is_final=True)
else:
payload = self._build_stream_payload(stream_id, '', False)
return await self._encrypt_and_reply(payload, nonce)
payload = self._build_stream_payload(stream_id, chunk.content, chunk.is_final)
if chunk.is_final:
self.stream_sessions.mark_finished(stream_id)
return await self._encrypt_and_reply(payload, nonce)
self.generated_content = {}
self.msg_id_map = {}
async def sha1_signature(token: str, timestamp: str, nonce: str, encrypt: str) -> str:
raw = "".join(sorted([token, timestamp, nonce, encrypt]))
return hashlib.sha1(raw.encode("utf-8")).hexdigest()
async def handle_callback_request(self):
"""企业微信回调入口(独立端口模式,使用全局 request
Returns:
Quart Response: 根据请求类型返回验证、首包或刷新结果。
Example:
作为 Quart 路由处理函数直接注册并使用。
"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
Args:
req: Quart Request 对象
"""
try:
self.wxcpt = WXBizMsgCrypt(self.Token, self.EnCodingAESKey, '')
await self.logger.info(f'{req.method} {req.url} {str(req.args)}')
self.wxcpt=WXBizMsgCrypt(self.Token,self.EnCodingAESKey,'')
if req.method == 'GET':
return await self._handle_get_callback(req)
if request.method == "GET":
if req.method == 'POST':
return await self._handle_post_callback(req)
msg_signature = unquote(request.args.get("msg_signature", ""))
timestamp = unquote(request.args.get("timestamp", ""))
nonce = unquote(request.args.get("nonce", ""))
echostr = unquote(request.args.get("echostr", ""))
return Response('', status=405)
if not all([msg_signature, timestamp, nonce, echostr]):
await self.logger.error("请求参数缺失")
return Response("缺少参数", status=400)
except Exception:
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
await self.logger.error("验证URL失败")
return Response("验证失败", status=403)
return Response(decrypted_str, mimetype="text/plain")
elif request.method == "POST":
msg_signature = unquote(request.args.get("msg_signature", ""))
timestamp = unquote(request.args.get("timestamp", ""))
nonce = unquote(request.args.get("nonce", ""))
try:
timeout = 3
interval = 0.1
start_time = time.monotonic()
encrypted_json = await request.get_json()
encrypted_msg = encrypted_json.get("encrypt", "")
if not encrypted_msg:
await self.logger.error("请求体中缺少 'encrypt' 字段")
xml_post_data = f"<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>"
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
if ret != 0:
await self.logger.error("解密失败")
msg_json = json.loads(decrypted_xml)
from_user_id = msg_json.get("from", {}).get("userid")
chatid = msg_json.get("chatid", "")
message_data = await self.get_message(msg_json)
if message_data:
try:
event = wecombotevent.WecomBotEvent(message_data)
if event:
await self._handle_message(event)
except Exception as e:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
start_time = time.time()
try:
if msg_json.get('chattype','') == 'single':
if from_user_id in self.user_stream_map:
stream_id = self.user_stream_map[from_user_id]
else:
stream_id =str(uuid.uuid4())
self.user_stream_map[from_user_id] = stream_id
else:
if chatid in self.user_stream_map:
stream_id = self.user_stream_map[chatid]
else:
stream_id = str(uuid.uuid4())
self.user_stream_map[chatid] = stream_id
except Exception as e:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
while True:
content = self.generated_content.pop(msg_json['msgid'],None)
if content:
reply_plain = {
"msgtype": "stream",
"stream": {
"id": stream_id,
"finish": True,
"content": content
}
}
reply_plain_str = json.dumps(reply_plain, ensure_ascii=False)
reply_timestamp = str(int(time.time()))
ret, encrypt_text = self.wxcpt.EncryptMsg(reply_plain_str, nonce, reply_timestamp)
if ret != 0:
await self.logger.error("加密失败"+str(ret))
root = ET.fromstring(encrypt_text)
encrypt = root.find("Encrypt").text
resp = {
"encrypt": encrypt,
}
return jsonify(resp), 200
if time.time() - start_time > timeout:
break
await asyncio.sleep(interval)
if self.msg_id_map.get(message_data['msgid'], 1) == 3:
await self.logger.error('请求失效暂不支持智能机器人超过7秒的请求如有需求请联系 LangBot 团队。')
return ''
except Exception as e:
await self.logger.error(traceback.format_exc())
print(traceback.format_exc())
except Exception as e:
await self.logger.error(traceback.format_exc())
return Response('Internal Server Error', status=500)
print(traceback.format_exc())
async def _handle_get_callback(self, req) -> tuple[Response, int] | Response:
"""处理企业微信的 GET 验证请求。"""
msg_signature = unquote(req.args.get('msg_signature', ''))
timestamp = unquote(req.args.get('timestamp', ''))
nonce = unquote(req.args.get('nonce', ''))
echostr = unquote(req.args.get('echostr', ''))
if not all([msg_signature, timestamp, nonce, echostr]):
await self.logger.error('请求参数缺失')
return Response('缺少参数', status=400)
ret, decrypted_str = self.wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
await self.logger.error('验证URL失败')
return Response('验证失败', status=403)
return Response(decrypted_str, mimetype='text/plain')
async def _handle_post_callback(self, req) -> tuple[Response, int] | Response:
"""处理企业微信的 POST 回调请求。"""
self.stream_sessions.cleanup()
msg_signature = unquote(req.args.get('msg_signature', ''))
timestamp = unquote(req.args.get('timestamp', ''))
nonce = unquote(req.args.get('nonce', ''))
encrypted_json = await req.get_json()
encrypted_msg = (encrypted_json or {}).get('encrypt', '')
if not encrypted_msg:
await self.logger.error("请求体中缺少 'encrypt' 字段")
return Response('Bad Request', status=400)
xml_post_data = f"<xml><Encrypt><![CDATA[{encrypted_msg}]]></Encrypt></xml>"
ret, decrypted_xml = self.wxcpt.DecryptMsg(xml_post_data, msg_signature, timestamp, nonce)
if ret != 0:
await self.logger.error('解密失败')
return Response('解密失败', status=400)
msg_json = json.loads(decrypted_xml)
if msg_json.get('msgtype') == 'stream':
return await self._handle_post_followup_response(msg_json, nonce)
return await self._handle_post_initial_response(msg_json, nonce)
async def get_message(self, msg_json):
async def get_message(self,msg_json):
message_data = {}
if msg_json.get('chattype', '') == 'single':
if msg_json.get('chattype','') == 'single':
message_data['type'] = 'single'
elif msg_json.get('chattype', '') == 'group':
elif msg_json.get('chattype','') == 'group':
message_data['type'] = 'group'
if msg_json.get('msgtype') == 'text':
message_data['content'] = msg_json.get('text', {}).get('content')
message_data['content'] = msg_json.get('text',{}).get('content')
elif msg_json.get('msgtype') == 'image':
picurl = msg_json.get('image', {}).get('url', '')
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
message_data['picurl'] = base64
picurl = msg_json.get('image', {}).get('url','')
base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
message_data['picurl'] = base64
elif msg_json.get('msgtype') == 'mixed':
items = msg_json.get('mixed', {}).get('msg_item', [])
texts = []
@@ -485,27 +197,17 @@ class WecomBotClient:
if texts:
message_data['content'] = "".join(texts) # 拼接所有 text
if picurl:
base64 = await self.download_url_to_base64(picurl, self.EnCodingAESKey)
message_data['picurl'] = base64 # 只保留第一个 image
# Extract user information
from_info = msg_json.get('from', {})
message_data['userid'] = from_info.get('userid', '')
message_data['username'] = from_info.get('alias', '') or from_info.get('name', '') or from_info.get('userid', '')
# Extract chat/group information
if msg_json.get('chattype', '') == 'group':
message_data['chatid'] = msg_json.get('chatid', '')
# Try to get group name if available
message_data['chatname'] = msg_json.get('chatname', '') or msg_json.get('chatid', '')
base64 = await self.download_url_to_base64(picurl,self.EnCodingAESKey)
message_data['picurl'] = base64 # 只保留第一个 image
message_data['userid'] = msg_json.get('from', {}).get('userid', '')
message_data['msgid'] = msg_json.get('msgid', '')
if msg_json.get('aibotid'):
message_data['aibotid'] = msg_json.get('aibotid', '')
return message_data
async def _handle_message(self, event: wecombotevent.WecomBotEvent):
"""
处理消息事件。
@@ -521,46 +223,10 @@ class WecomBotClient:
for handler in self._message_handlers[msg_type]:
await handler(event)
except Exception:
print(traceback.format_exc())
async def push_stream_chunk(self, msg_id: str, content: str, is_final: bool = False) -> bool:
"""将流水线片段推送到 stream 会话。
Args:
msg_id: 原始企业微信消息 ID。
content: 模型产生的片段内容。
is_final: 是否为最终片段。
Returns:
bool: 当成功写入流式队列时返回 True。
Example:
在流水线 `reply_message_chunk` 中调用,将增量推送至企业微信。
"""
# 根据 msg_id 找到对应 stream 会话,如果不存在说明当前消息非流式
stream_id = self.stream_sessions.get_stream_id_by_msg(msg_id)
if not stream_id:
return False
chunk = StreamChunk(content=content, is_final=is_final)
await self.stream_sessions.publish(stream_id, chunk)
if is_final:
self.stream_sessions.mark_finished(stream_id)
return True
print(traceback.format_exc())
async def set_message(self, msg_id: str, content: str):
"""兼容旧逻辑:若无法流式返回则缓存最终结果。
Args:
msg_id: 企业微信消息 ID。
content: 最终回复的文本内容。
Example:
在非流式场景下缓存最终结果以备刷新时返回。
"""
handled = await self.push_stream_chunk(msg_id, content, is_final=True)
if not handled:
self.generated_content[msg_id] = content
self.generated_content[msg_id] = content
def on_message(self, msg_type: str):
def decorator(func: Callable[[wecombotevent.WecomBotEvent], None]):
@@ -571,6 +237,7 @@ class WecomBotClient:
return decorator
async def download_url_to_base64(self, download_url, encoding_aes_key):
async with httpx.AsyncClient() as client:
response = await client.get(download_url)
@@ -580,22 +247,26 @@ class WecomBotClient:
encrypted_bytes = response.content
aes_key = base64.b64decode(encoding_aes_key + "=") # base64 补齐
iv = aes_key[:16]
cipher = AES.new(aes_key, AES.MODE_CBC, iv)
decrypted = cipher.decrypt(encrypted_bytes)
pad_len = decrypted[-1]
decrypted = decrypted[:-pad_len]
if decrypted.startswith(b"\xff\xd8"): # JPEG
if decrypted.startswith(b"\xff\xd8"): # JPEG
mime_type = "image/jpeg"
elif decrypted.startswith(b"\x89PNG"): # PNG
mime_type = "image/png"
elif decrypted.startswith((b"GIF87a", b"GIF89a")): # GIF
mime_type = "image/gif"
elif decrypted.startswith(b"BM"): # BMP
elif decrypted.startswith(b"BM"): # BMP
mime_type = "image/bmp"
elif decrypted.startswith(b"II*\x00") or decrypted.startswith(b"MM\x00*"): # TIFF
mime_type = "image/tiff"
@@ -605,9 +276,15 @@ class WecomBotClient:
# 转 base64
base64_str = base64.b64encode(decrypted).decode("utf-8")
return f"data:{mime_type};base64,{base64_str}"
async def run_task(self, host: str, port: int, *args, **kwargs):
"""
启动 Quart 应用。
"""
await self.app.run_task(host=host, port=port, *args, **kwargs)

View File

@@ -22,21 +22,7 @@ class WecomBotEvent(dict):
"""
用户id
"""
return self.get('from', {}).get('userid', '') or self.get('userid', '')
@property
def username(self) -> str:
"""
用户名称
"""
return self.get('username', '') or self.get('from', {}).get('alias', '') or self.get('from', {}).get('name', '') or self.userid
@property
def chatname(self) -> str:
"""
群组名称
"""
return self.get('chatname', '') or str(self.chatid)
return self.get('from', {}).get('userid', '')
@property
def content(self) -> str:

View File

@@ -21,7 +21,6 @@ class WecomClient:
EncodingAESKey: str,
contacts_secret: str,
logger: None,
unified_mode: bool = False,
):
self.corpid = corpid
self.secret = secret
@@ -32,18 +31,13 @@ class WecomClient:
self.access_token = ''
self.secret_for_contacts = contacts_secret
self.logger = logger
self.unified_mode = unified_mode
self.app = Quart(__name__)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self.app.add_url_rule(
'/callback/command',
'handle_callback',
self.handle_callback_request,
methods=['GET', 'POST'],
)
self._message_handlers = {
'example': [],
}
@@ -174,43 +168,25 @@ class WecomClient:
raise Exception('Failed to send message: ' + str(data))
async def handle_callback_request(self):
"""处理回调请求(独立端口模式,使用全局 request"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""
处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
Args:
req: Quart Request 对象
处理回调请求,包括 GET 验证和 POST 消息接收。
"""
try:
msg_signature = req.args.get('msg_signature')
timestamp = req.args.get('timestamp')
nonce = req.args.get('nonce')
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
if req.method == 'GET':
echostr = req.args.get('echostr')
if request.method == 'GET':
echostr = request.args.get('echostr')
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
await self.logger.error('验证失败')
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
elif req.method == 'POST':
encrypt_msg = await req.data
elif request.method == 'POST':
encrypt_msg = await request.data
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0:
await self.logger.error('消息解密失败')

View File

@@ -13,7 +13,7 @@ import aiofiles
class WecomCSClient:
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None, unified_mode: bool = False):
def __init__(self, corpid: str, secret: str, token: str, EncodingAESKey: str, logger: None):
self.corpid = corpid
self.secret = secret
self.access_token_for_contacts = ''
@@ -22,15 +22,10 @@ class WecomCSClient:
self.base_url = 'https://qyapi.weixin.qq.com/cgi-bin'
self.access_token = ''
self.logger = logger
self.unified_mode = unified_mode
self.app = Quart(__name__)
# 只有在非统一模式下才注册独立路由
if not self.unified_mode:
self.app.add_url_rule(
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
)
self.app.add_url_rule(
'/callback/command', 'handle_callback', self.handle_callback_request, methods=['GET', 'POST']
)
self._message_handlers = {
'example': [],
}
@@ -197,45 +192,27 @@ class WecomCSClient:
return data
async def handle_callback_request(self):
"""处理回调请求(独立端口模式,使用全局 request"""
return await self._handle_callback_internal(request)
async def handle_unified_webhook(self, req):
"""处理回调请求(统一 webhook 模式,显式传递 request
Args:
req: Quart Request 对象
Returns:
响应数据
"""
return await self._handle_callback_internal(req)
async def _handle_callback_internal(self, req):
"""
处理回调请求的内部实现,包括 GET 验证和 POST 消息接收。
Args:
req: Quart Request 对象
处理回调请求,包括 GET 验证和 POST 消息接收。
"""
try:
msg_signature = req.args.get('msg_signature')
timestamp = req.args.get('timestamp')
nonce = req.args.get('nonce')
msg_signature = request.args.get('msg_signature')
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
try:
wxcpt = WXBizMsgCrypt(self.token, self.aes, self.corpid)
except Exception as e:
raise Exception(f'初始化失败,错误码: {e}')
if req.method == 'GET':
echostr = req.args.get('echostr')
if request.method == 'GET':
echostr = request.args.get('echostr')
ret, reply_echo_str = wxcpt.VerifyURL(msg_signature, timestamp, nonce, echostr)
if ret != 0:
raise Exception(f'验证失败,错误码: {ret}')
return reply_echo_str
elif req.method == 'POST':
encrypt_msg = await req.data
elif request.method == 'POST':
encrypt_msg = await request.data
ret, xml_msg = wxcpt.DecryptMsg(encrypt_msg, msg_signature, timestamp, nonce)
if ret != 0:
raise Exception(f'消息解密失败,错误码: {ret}')

View File

@@ -0,0 +1,229 @@
"""流水线调试 WebSocket 路由
提供基于 WebSocket 的实时双向通信,用于流水线调试。
支持 person 和 group 两种会话类型的隔离。
"""
import asyncio
import logging
import uuid
from datetime import datetime
import quart
from ... import group
from ....service.websocket_pool import WebSocketConnection
logger = logging.getLogger(__name__)
async def handle_client_event(connection: WebSocketConnection, message: dict, ap):
"""处理客户端发送的事件
Args:
connection: WebSocket 连接对象
message: 客户端消息 {'type': 'xxx', 'data': {...}}
ap: Application 实例
"""
event_type = message.get('type')
data = message.get('data', {})
pipeline_uuid = connection.pipeline_uuid
session_type = connection.session_type
try:
webchat_adapter = ap.platform_mgr.webchat_proxy_bot.adapter
if event_type == 'send_message':
# 发送消息到指定会话
message_chain_obj = data.get('message_chain', [])
client_message_id = data.get('client_message_id')
if not message_chain_obj:
await connection.send('error', {'error': 'message_chain is required', 'error_code': 'INVALID_REQUEST'})
return
logger.info(
f"Received send_message: pipeline={pipeline_uuid}, "
f"session={session_type}, "
f"client_msg_id={client_message_id}"
)
# 调用 webchat_adapter.send_webchat_message
# 消息将通过 reply_message_chunk 自动推送到 WebSocket
result = None
async for msg in webchat_adapter.send_webchat_message(
pipeline_uuid=pipeline_uuid, session_type=session_type, message_chain_obj=message_chain_obj, is_stream=True
):
result = msg
# 发送确认
if result:
await connection.send(
'message_sent',
{
'client_message_id': client_message_id,
'server_message_id': result.get('id'),
'timestamp': result.get('timestamp'),
},
)
elif event_type == 'load_history':
# 加载指定会话的历史消息
before_message_id = data.get('before_message_id')
limit = data.get('limit', 50)
logger.info(f"Loading history: pipeline={pipeline_uuid}, session={session_type}, limit={limit}")
# 从对应会话获取历史消息
messages = webchat_adapter.get_webchat_messages(pipeline_uuid, session_type)
# 简单分页:返回最后 limit 条
if before_message_id:
# TODO: 实现基于 message_id 的分页
history_messages = messages[-limit:]
else:
history_messages = messages[-limit:] if len(messages) > limit else messages
await connection.send(
'history', {'messages': history_messages, 'has_more': len(messages) > len(history_messages)}
)
elif event_type == 'interrupt':
# 中断消息
message_id = data.get('message_id')
logger.info(f"Interrupt requested: message_id={message_id}")
# TODO: 实现中断逻辑
await connection.send('interrupted', {'message_id': message_id, 'partial_content': ''})
elif event_type == 'ping':
# 心跳
connection.last_ping = datetime.now()
await connection.send('pong', {'timestamp': data.get('timestamp')})
else:
logger.warning(f"Unknown event type: {event_type}")
await connection.send('error', {'error': f'Unknown event type: {event_type}', 'error_code': 'UNKNOWN_EVENT'})
except Exception as e:
logger.error(f"Error handling event {event_type}: {e}", exc_info=True)
await connection.send(
'error',
{'error': f'Internal server error: {str(e)}', 'error_code': 'INTERNAL_ERROR', 'details': {'event_type': event_type}},
)
@group.group_class('pipeline-websocket', '/api/v1/pipelines/<pipeline_uuid>/chat')
class PipelineWebSocketRouterGroup(group.RouterGroup):
"""流水线调试 WebSocket 路由组"""
async def initialize(self) -> None:
@self.route('/ws')
async def websocket_handler(pipeline_uuid: str):
"""WebSocket 连接处理 - 会话隔离
连接流程:
1. 客户端建立 WebSocket 连接
2. 客户端发送 connect 事件(携带 session_type 和 token
3. 服务端验证并创建连接对象
4. 进入消息循环,处理客户端事件
5. 断开时清理连接
Args:
pipeline_uuid: 流水线 UUID
"""
websocket = quart.websocket._get_current_object()
connection_id = str(uuid.uuid4())
session_key = None
connection = None
try:
# 1. 等待客户端发送 connect 事件
first_message = await websocket.receive_json()
if first_message.get('type') != 'connect':
await websocket.send_json(
{'type': 'error', 'data': {'error': 'First message must be connect event', 'error_code': 'INVALID_HANDSHAKE'}}
)
await websocket.close(1008)
return
connect_data = first_message.get('data', {})
session_type = connect_data.get('session_type')
token = connect_data.get('token')
# 验证参数
if session_type not in ['person', 'group']:
await websocket.send_json(
{'type': 'error', 'data': {'error': 'session_type must be person or group', 'error_code': 'INVALID_SESSION_TYPE'}}
)
await websocket.close(1008)
return
# 验证 token
if not token:
await websocket.send_json(
{'type': 'error', 'data': {'error': 'token is required', 'error_code': 'MISSING_TOKEN'}}
)
await websocket.close(1008)
return
# 验证用户身份
try:
user = await self.ap.user_service.verify_token(token)
if not user:
await websocket.send_json({'type': 'error', 'data': {'error': 'Unauthorized', 'error_code': 'UNAUTHORIZED'}})
await websocket.close(1008)
return
except Exception as e:
logger.error(f"Token verification failed: {e}")
await websocket.send_json(
{'type': 'error', 'data': {'error': 'Token verification failed', 'error_code': 'AUTH_ERROR'}}
)
await websocket.close(1008)
return
# 2. 创建连接对象并加入连接池
connection = WebSocketConnection(
connection_id=connection_id,
websocket=websocket,
pipeline_uuid=pipeline_uuid,
session_type=session_type,
created_at=datetime.now(),
last_ping=datetime.now(),
)
session_key = connection.session_key
ws_pool = self.ap.ws_pool
ws_pool.add_connection(connection)
# 3. 发送连接成功事件
await connection.send(
'connected', {'connection_id': connection_id, 'session_type': session_type, 'pipeline_uuid': pipeline_uuid}
)
logger.info(f"WebSocket connected: {connection_id} [pipeline={pipeline_uuid}, session={session_type}]")
# 4. 进入消息处理循环
while True:
try:
message = await websocket.receive_json()
await handle_client_event(connection, message, self.ap)
except asyncio.CancelledError:
logger.info(f"WebSocket connection cancelled: {connection_id}")
break
except Exception as e:
logger.error(f"Error receiving message from {connection_id}: {e}")
break
except quart.exceptions.WebsocketDisconnected:
logger.info(f"WebSocket disconnected: {connection_id}")
except Exception as e:
logger.error(f"WebSocket error for {connection_id}: {e}", exc_info=True)
finally:
# 清理连接
if connection and session_key:
ws_pool = self.ap.ws_pool
await ws_pool.remove_connection(connection_id, session_key)
logger.info(f"WebSocket connection cleaned up: {connection_id}")

View File

@@ -18,8 +18,7 @@ class BotsRouterGroup(group.RouterGroup):
@self.route('/<bot_uuid>', methods=['GET', 'PUT', 'DELETE'])
async def _(bot_uuid: str) -> str:
if quart.request.method == 'GET':
# 返回运行时信息包括webhook地址等
bot = await self.ap.bot_service.get_runtime_bot_info(bot_uuid)
bot = await self.ap.bot_service.get_bot(bot_uuid)
if bot is None:
return self.http_status(404, -1, 'bot not found')
return self.success(data={'bot': bot})

View File

@@ -1,57 +0,0 @@
from __future__ import annotations
import quart
import traceback
from .. import group
@group.group_class('webhooks', '/bots')
class WebhookRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('/<bot_uuid>', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def handle_webhook(bot_uuid: str):
"""处理 bot webhook 回调(无子路径)"""
return await self._dispatch_webhook(bot_uuid, '')
@self.route('/<bot_uuid>/<path:path>', methods=['GET', 'POST'], auth_type=group.AuthType.NONE)
async def handle_webhook_with_path(bot_uuid: str, path: str):
"""处理 bot webhook 回调(带子路径)"""
return await self._dispatch_webhook(bot_uuid, path)
async def _dispatch_webhook(self, bot_uuid: str, path: str):
"""分发 webhook 请求到对应的 bot adapter
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
Returns:
适配器返回的响应
"""
try:
runtime_bot = await self.ap.platform_mgr.get_bot_by_uuid(bot_uuid)
if not runtime_bot:
return quart.jsonify({'error': 'Bot not found'}), 404
if not runtime_bot.enable:
return quart.jsonify({'error': 'Bot is disabled'}), 403
if not hasattr(runtime_bot.adapter, 'handle_unified_webhook'):
return quart.jsonify({'error': 'Adapter does not support unified webhook'}), 501
response = await runtime_bot.adapter.handle_unified_webhook(
bot_uuid=bot_uuid,
path=path,
request=quart.request,
)
return response
except Exception as e:
self.ap.logger.error(f'Webhook dispatch error for bot {bot_uuid}: {traceback.format_exc()}')
return quart.jsonify({'error': str(e)}), 500

View File

@@ -58,15 +58,6 @@ class BotService:
if runtime_bot is not None:
adapter_runtime_values['bot_account_id'] = runtime_bot.adapter.bot_account_id
if persistence_bot['adapter'] in ['wecom', 'wecombot', 'officialaccount', 'qqofficial', 'slack','wecomcs']:
api_port = self.ap.instance_config.data['api']['port']
webhook_url = f"/bots/{bot_uuid}"
adapter_runtime_values['webhook_url'] = webhook_url
adapter_runtime_values['webhook_full_url'] = f"http://<Your-Server-IP>:{api_port}{webhook_url}"
else:
adapter_runtime_values['webhook_url'] = None
adapter_runtime_values['webhook_full_url'] = None
persistence_bot['adapter_runtime_values'] = adapter_runtime_values
return persistence_bot

View File

@@ -0,0 +1,211 @@
"""WebSocket 连接池管理
用于管理流水线调试的 WebSocket 连接,支持会话隔离和消息广播。
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
import quart
logger = logging.getLogger(__name__)
@dataclass
class WebSocketConnection:
"""单个 WebSocket 连接"""
connection_id: str
websocket: quart.websocket.WebSocket
pipeline_uuid: str
session_type: str # 'person' 或 'group'
created_at: datetime
last_ping: datetime
@property
def session_key(self) -> str:
"""会话唯一标识: pipeline_uuid:session_type"""
return f"{self.pipeline_uuid}:{self.session_type}"
async def send(self, event_type: str, data: dict):
"""发送事件到客户端
Args:
event_type: 事件类型
data: 事件数据
"""
try:
await self.websocket.send_json({"type": event_type, "data": data})
except Exception as e:
logger.error(f"Failed to send message to {self.connection_id}: {e}")
raise
class WebSocketConnectionPool:
"""WebSocket 连接池 - 按会话隔离
连接池结构:
connections[session_key][connection_id] = WebSocketConnection
其中 session_key = f"{pipeline_uuid}:{session_type}"
这样可以确保:
- person 和 group 会话完全隔离
- 不同 pipeline 的会话隔离
- 同一会话的多个连接可以同步接收消息(多标签页)
"""
def __init__(self):
self.connections: dict[str, dict[str, WebSocketConnection]] = {}
self._lock = asyncio.Lock()
def add_connection(self, conn: WebSocketConnection):
"""添加连接到指定会话
Args:
conn: WebSocket 连接对象
"""
session_key = conn.session_key
if session_key not in self.connections:
self.connections[session_key] = {}
self.connections[session_key][conn.connection_id] = conn
logger.info(
f"WebSocket connection added: {conn.connection_id} "
f"to session {session_key} "
f"(total: {len(self.connections[session_key])} connections)"
)
async def remove_connection(self, connection_id: str, session_key: str):
"""从指定会话移除连接
Args:
connection_id: 连接 ID
session_key: 会话标识
"""
async with self._lock:
if session_key in self.connections:
conn = self.connections[session_key].pop(connection_id, None)
# 如果该会话没有连接了,清理会话
if not self.connections[session_key]:
del self.connections[session_key]
if conn:
logger.info(
f"WebSocket connection removed: {connection_id} "
f"from session {session_key} "
f"(remaining: {len(self.connections.get(session_key, {}))} connections)"
)
def get_connection(self, connection_id: str, session_key: str) -> Optional[WebSocketConnection]:
"""获取指定连接
Args:
connection_id: 连接 ID
session_key: 会话标识
Returns:
WebSocketConnection 或 None
"""
return self.connections.get(session_key, {}).get(connection_id)
def get_connections_by_session(self, pipeline_uuid: str, session_type: str) -> list[WebSocketConnection]:
"""获取指定会话的所有连接
Args:
pipeline_uuid: 流水线 UUID
session_type: 会话类型 ('person''group')
Returns:
连接列表
"""
session_key = f"{pipeline_uuid}:{session_type}"
return list(self.connections.get(session_key, {}).values())
async def broadcast_to_session(self, pipeline_uuid: str, session_type: str, event_type: str, data: dict):
"""广播消息到指定会话的所有连接
Args:
pipeline_uuid: 流水线 UUID
session_type: 会话类型 ('person''group')
event_type: 事件类型
data: 事件数据
"""
connections = self.get_connections_by_session(pipeline_uuid, session_type)
if not connections:
logger.debug(f"No connections for session {pipeline_uuid}:{session_type}, skipping broadcast")
return
logger.debug(
f"Broadcasting {event_type} to session {pipeline_uuid}:{session_type}, " f"{len(connections)} connections"
)
# 并发发送到所有连接,忽略失败的连接
results = await asyncio.gather(*[conn.send(event_type, data) for conn in connections], return_exceptions=True)
# 统计失败的连接
failed_count = sum(1 for result in results if isinstance(result, Exception))
if failed_count > 0:
logger.warning(f"Failed to send to {failed_count}/{len(connections)} connections")
def get_all_sessions(self) -> list[str]:
"""获取所有活跃会话的 session_key 列表
Returns:
会话标识列表
"""
return list(self.connections.keys())
def get_connection_count(self, pipeline_uuid: str, session_type: str) -> int:
"""获取指定会话的连接数量
Args:
pipeline_uuid: 流水线 UUID
session_type: 会话类型
Returns:
连接数量
"""
session_key = f"{pipeline_uuid}:{session_type}"
return len(self.connections.get(session_key, {}))
async def cleanup_stale_connections(self, timeout_seconds: int = 120):
"""清理超时的连接
Args:
timeout_seconds: 超时时间(秒)
"""
now = datetime.now()
stale_connections = []
# 查找超时连接
for session_key, session_conns in self.connections.items():
for conn_id, conn in session_conns.items():
elapsed = (now - conn.last_ping).total_seconds()
if elapsed > timeout_seconds:
stale_connections.append((conn_id, session_key))
# 移除超时连接
for conn_id, session_key in stale_connections:
logger.warning(f"Removing stale connection: {conn_id} from {session_key}")
await self.remove_connection(conn_id, session_key)
# 尝试关闭 WebSocket
try:
conn = self.get_connection(conn_id, session_key)
if conn:
await conn.websocket.close(1000, "Connection timeout")
except Exception as e:
logger.error(f"Error closing stale connection {conn_id}: {e}")
if stale_connections:
logger.info(f"Cleaned up {len(stale_connections)} stale connections")

View File

@@ -105,6 +105,10 @@ class Application:
storage_mgr: storagemgr.StorageMgr = None
# ========= WebSocket =========
ws_pool = None # WebSocketConnectionPool
# ========= HTTP Services =========
user_service: user_service.UserService = None

View File

@@ -19,6 +19,7 @@ from ...api.http.service import model as model_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
from ...api.http.service import websocket_pool
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
@@ -87,10 +88,18 @@ class BuildAppStage(stage.BootingStage):
await llm_tool_mgr_inst.initialize()
ap.tool_mgr = llm_tool_mgr_inst
# Initialize WebSocket connection pool
ws_pool_inst = websocket_pool.WebSocketConnectionPool()
ap.ws_pool = ws_pool_inst
im_mgr_inst = im_mgr.PlatformManager(ap=ap)
await im_mgr_inst.initialize()
ap.platform_mgr = im_mgr_inst
# Inject WebSocket pool into WebChatAdapter
if hasattr(ap.platform_mgr, 'webchat_proxy_bot') and ap.platform_mgr.webchat_proxy_bot:
ap.platform_mgr.webchat_proxy_bot.adapter.set_ws_pool(ws_pool_inst)
pipeline_mgr = pipelinemgr.PipelineManager(ap)
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr

View File

@@ -232,10 +232,6 @@ class PlatformManager:
logger,
)
# 如果 adapter 支持 set_bot_uuid 方法,设置 bot_uuid用于统一 webhook
if hasattr(adapter_inst, 'set_bot_uuid'):
adapter_inst.set_bot_uuid(bot_entity.uuid)
runtime_bot = RuntimeBot(ap=self.ap, bot_entity=bot_entity, adapter=adapter_inst, logger=logger)
await runtime_bot.initialize()

View File

@@ -59,16 +59,14 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
message_converter: OAMessageConverter = OAMessageConverter()
event_converter: OAEventConverter = OAEventConverter()
bot: typing.Union[OAClient, OAClientForLongerResponse] = pydantic.Field(exclude=True)
bot_uuid: str = None
def __init__(self, config: dict, logger: EventLogger):
# 校验必填项
required_keys = ['token', 'EncodingAESKey', 'AppSecret', 'AppID', 'Mode']
missing_keys = [k for k in required_keys if k not in config]
if missing_keys:
raise Exception(f'OfficialAccount 缺少配置项: {missing_keys}')
# 创建运行时 bot 对象,始终使用统一 webhook 模式
if config['Mode'] == 'drop':
bot = OAClient(
token=config['token'],
@@ -76,7 +74,6 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
Appsecret=config['AppSecret'],
AppID=config['AppID'],
logger=logger,
unified_mode=True,
)
elif config['Mode'] == 'passive':
bot = OAClientForLongerResponse(
@@ -86,14 +83,13 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
AppID=config['AppID'],
LoadingMessage=config.get('LoadingMessage', ''),
logger=logger,
unified_mode=True,
)
else:
raise KeyError('请设置微信公众号通信模式')
bot_account_id = config.get('AppID', '')
super().__init__(
bot=bot,
bot_account_id=bot_account_id,
@@ -140,45 +136,16 @@ class OfficialAccountAdapter(abstract_platform_adapter.AbstractMessagePlatformAd
elif event_type == platform_events.GroupMessage:
pass
def set_bot_uuid(self, bot_uuid: str):
"""设置 bot UUID用于生成 webhook URL"""
self.bot_uuid = bot_uuid
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
# 统一 webhook 模式下,不启动独立的 Quart 应用
# 保持运行但不启动独立端口
# 打印 webhook 回调地址
if self.bot_uuid and hasattr(self.logger, 'ap'):
try:
api_port = self.logger.ap.instance_config.data['api']['port']
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
await self.logger.info(f"微信公众号 Webhook 回调地址:")
await self.logger.info(f" 本地地址: {webhook_url}")
await self.logger.info(f" 公网地址: {webhook_url_public}")
await self.logger.info(f"请在微信公众号后台配置此回调地址")
except Exception as e:
await self.logger.warning(f"无法生成 webhook URL: {e}")
async def keep_alive():
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await keep_alive()
await self.bot.run_task(
host=self.config['host'],
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False

View File

@@ -53,6 +53,23 @@ spec:
type: string
required: true
default: "AI正在思考中请发送任意内容获取回复。"
- name: host
label:
en_US: Host
zh_Hans: 监听主机
description:
en_US: The host that Official Account listens on for Webhook connections.
zh_Hans: 微信公众号监听的主机,除非你知道自己在做什么,否则请写 0.0.0.0
type: string
required: true
default: 0.0.0.0
- name: port
label:
en_US: Port
zh_Hans: 监听端口
type: integer
required: true
default: 2287
execution:
python:
path: ./officialaccount.py

View File

@@ -135,17 +135,12 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
bot: QQOfficialClient
config: dict
bot_account_id: str
bot_uuid: str = None
message_converter: QQOfficialMessageConverter = QQOfficialMessageConverter()
event_converter: QQOfficialEventConverter = QQOfficialEventConverter()
def __init__(self, config: dict, logger: EventLogger):
bot = QQOfficialClient(
app_id=config['appid'],
secret=config['secret'],
token=config['token'],
logger=logger,
unified_mode=True
app_id=config['appid'], secret=config['secret'], token=config['token'], logger=logger
)
super().__init__(
@@ -231,45 +226,16 @@ class QQOfficialAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter
self.bot.on_message('GROUP_AT_MESSAGE_CREATE')(on_message)
self.bot.on_message('AT_MESSAGE_CREATE')(on_message)
def set_bot_uuid(self, bot_uuid: str):
"""设置 bot UUID用于生成 webhook URL"""
self.bot_uuid = bot_uuid
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
# 统一 webhook 模式下,不启动独立的 Quart 应用
# 保持运行但不启动独立端口
# 打印 webhook 回调地址
if self.bot_uuid and hasattr(self.logger, 'ap'):
try:
api_port = self.logger.ap.instance_config.data['api']['port']
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
await self.logger.info(f"QQ 官方机器人 Webhook 回调地址:")
await self.logger.info(f" 本地地址: {webhook_url}")
await self.logger.info(f" 公网地址: {webhook_url_public}")
await self.logger.info(f"请在 QQ 官方机器人后台配置此回调地址")
except Exception as e:
await self.logger.warning(f"无法生成 webhook URL: {e}")
async def keep_alive():
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await keep_alive()
await self.bot.run_task(
host='0.0.0.0',
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False

View File

@@ -25,6 +25,13 @@ spec:
type: string
required: true
default: ""
- name: port
label:
en_US: Port
zh_Hans: 监听端口
type: integer
required: true
default: 2284
- name: token
label:
en_US: Token

View File

@@ -86,12 +86,13 @@ class SlackEventConverter(abstract_platform_adapter.AbstractEventConverter):
class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: SlackClient
bot_account_id: str
bot_uuid: str = None
message_converter: SlackMessageConverter = SlackMessageConverter()
event_converter: SlackEventConverter = SlackEventConverter()
config: dict
def __init__(self, config: dict, logger: EventLogger):
self.config = config
self.logger = logger
required_keys = [
'bot_token',
'signing_secret',
@@ -100,18 +101,8 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
if missing_keys:
raise command_errors.ParamNotEnoughError('Slack机器人缺少相关配置项请查看文档或联系管理员')
bot = SlackClient(
bot_token=config['bot_token'],
signing_secret=config['signing_secret'],
logger=logger,
unified_mode=True
)
super().__init__(
config=config,
logger=logger,
bot=bot,
bot_account_id=config['bot_token'],
self.bot = SlackClient(
bot_token=self.config['bot_token'], signing_secret=self.config['signing_secret'], logger=self.logger
)
async def reply_message(
@@ -157,45 +148,16 @@ class SlackAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
elif event_type == platform_events.GroupMessage:
self.bot.on_message('channel')(on_message)
def set_bot_uuid(self, bot_uuid: str):
"""设置 bot UUID用于生成 webhook URL"""
self.bot_uuid = bot_uuid
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
# 统一 webhook 模式下,不启动独立的 Quart 应用
# 保持运行但不启动独立端口
# 打印 webhook 回调地址
if self.bot_uuid and hasattr(self.logger, 'ap'):
try:
api_port = self.logger.ap.instance_config.data['api']['port']
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
await self.logger.info(f"Slack 机器人 Webhook 回调地址:")
await self.logger.info(f" 本地地址: {webhook_url}")
await self.logger.info(f" 公网地址: {webhook_url_public}")
await self.logger.info(f"请在 Slack 后台配置此回调地址")
except Exception as e:
await self.logger.warning(f"无法生成 webhook URL: {e}")
async def keep_alive():
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await keep_alive()
await self.bot.run_task(
host='0.0.0.0',
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False

View File

@@ -25,6 +25,13 @@ spec:
type: string
required: true
default: ""
- name: port
label:
en_US: Port
zh_Hans: 监听端口
type: int
required: true
default: 2288
execution:
python:
path: ./slack.py

View File

@@ -58,6 +58,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
debug_messages: dict[str, list[dict]] = pydantic.Field(default_factory=dict, exclude=True)
ap: app.Application = pydantic.Field(exclude=True)
ws_pool: typing.Any = pydantic.Field(exclude=True, default=None) # WebSocketConnectionPool
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger, **kwargs):
super().__init__(
@@ -72,6 +73,15 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
self.bot_account_id = 'webchatbot'
self.debug_messages = {}
self.ws_pool = None
def set_ws_pool(self, ws_pool):
"""设置 WebSocket 连接池
Args:
ws_pool: WebSocketConnectionPool 实例
"""
self.ws_pool = ws_pool
async def send_message(
self,
@@ -130,7 +140,7 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
quote_origin: bool = False,
is_final: bool = False,
) -> dict:
"""回复消息"""
"""Reply message chunk - supports both SSE (legacy) and WebSocket"""
message_data = WebChatMessage(
id=-1,
role='assistant',
@@ -139,24 +149,32 @@ class WebChatAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
timestamp=datetime.now().isoformat(),
)
# notify waiter
session = (
self.webchat_group_session
if isinstance(message_source, platform_events.GroupMessage)
else self.webchat_person_session
)
if message_source.message_chain.message_id not in session.resp_waiters:
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
queue = session.resp_queues[message_source.message_chain.message_id]
# Determine session type
if isinstance(message_source, platform_events.GroupMessage):
session_type = 'group'
session = self.webchat_group_session
else: # FriendMessage
session_type = 'person'
session = self.webchat_person_session
# if isinstance(message_source, platform_events.FriendMessage):
# queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id]
# elif isinstance(message_source, platform_events.GroupMessage):
# queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id]
if is_final and bot_message.tool_calls is None:
message_data.is_final = True
# print(message_data)
await queue.put(message_data)
# Legacy SSE support: put message into queue
if message_source.message_chain.message_id in session.resp_queues:
queue = session.resp_queues[message_source.message_chain.message_id]
if is_final and bot_message.tool_calls is None:
message_data.is_final = True
await queue.put(message_data)
# WebSocket support: broadcast to all connections
if self.ws_pool:
pipeline_uuid = self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid
# Determine event type
event_type = 'message_complete' if (is_final and bot_message.tool_calls is None) else 'message_chunk'
# Broadcast to specified session only
await self.ws_pool.broadcast_to_session(
pipeline_uuid=pipeline_uuid, session_type=session_type, event_type=event_type, data=message_data.model_dump()
)
return message_data.model_dump()

View File

@@ -132,7 +132,6 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
message_converter: WecomMessageConverter = WecomMessageConverter()
event_converter: WecomEventConverter = WecomEventConverter()
config: dict
bot_uuid: str = None
def __init__(self, config: dict, logger: EventLogger):
# 校验必填项
@@ -143,12 +142,11 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
'EncodingAESKey',
'contacts_secret',
]
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise Exception(f'Wecom 缺少配置项: {missing_keys}')
# 创建运行时 bot 对象,始终使用统一 webhook 模式
# 创建运行时 bot 对象
bot = WecomClient(
corpid=config['corpid'],
secret=config['secret'],
@@ -156,10 +154,9 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
EncodingAESKey=config['EncodingAESKey'],
contacts_secret=config['contacts_secret'],
logger=logger,
unified_mode=True,
)
super().__init__(
config=config,
logger=logger,
@@ -167,9 +164,6 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot_account_id="",
)
def set_bot_uuid(self, bot_uuid: str):
"""设置 bot UUID用于生成 webhook URL"""
self.bot_uuid = bot_uuid
async def reply_message(
self,
@@ -189,7 +183,9 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
await self.bot.send_image(fixed_user_id, Wecom_event.agent_id, content['media_id'])
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
"""企业微信目前只有发送给个人的方法,
构造target_id的方式为前半部分为账户id后半部分为agent_id,中间使用“|”符号隔开。
"""
content_list = await WecomMessageConverter.yiri2target(message, self.bot)
parts = target_id.split('|')
user_id = parts[0]
@@ -221,38 +217,16 @@ class WecomAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
elif event_type == platform_events.GroupMessage:
pass
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
if self.bot_uuid and hasattr(self.logger, 'ap'):
try:
api_port = self.logger.ap.instance_config.data['api']['port']
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
await self.logger.info(f"企业微信 Webhook 回调地址:")
await self.logger.info(f" 本地地址: {webhook_url}")
await self.logger.info(f" 公网地址: {webhook_url_public}")
await self.logger.info(f"请在企业微信后台配置此回调地址")
except Exception as e:
await self.logger.warning(f"无法生成 webhook URL: {e}")
async def keep_alive():
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await keep_alive()
await self.bot.run_task(
host=self.config['host'],
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False

View File

@@ -11,6 +11,23 @@ metadata:
icon: wecom.png
spec:
config:
- name: host
label:
en_US: Host
zh_Hans: 监听主机
description:
en_US: Webhook host, unless you know what you're doing, please write 0.0.0.0
zh_Hans: Webhook 监听主机,除非你知道自己在做什么,否则请写 0.0.0.0
type: string
required: true
default: "0.0.0.0"
- name: port
label:
en_US: Port
zh_Hans: 监听端口
type: integer
required: true
default: 2290
- name: corpid
label:
en_US: Corpid

View File

@@ -49,7 +49,7 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
return platform_events.FriendMessage(
sender=platform_entities.Friend(
id=event.userid,
nickname=event.username,
nickname='',
remark='',
),
message_chain=message_chain,
@@ -61,10 +61,10 @@ class WecomBotEventConverter(abstract_platform_adapter.AbstractEventConverter):
sender = platform_entities.GroupMember(
id=event.userid,
permission='MEMBER',
member_name=event.username,
member_name=event.userid,
group=platform_entities.Group(
id=str(event.chatid),
name=event.chatname,
name='',
permission=platform_entities.Permission.Member,
),
special_title='',
@@ -88,22 +88,19 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
message_converter: WecomBotMessageConverter = WecomBotMessageConverter()
event_converter: WecomBotEventConverter = WecomBotEventConverter()
config: dict
bot_uuid: str = None
def __init__(self, config: dict, logger: EventLogger):
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId']
required_keys = ['Token', 'EncodingAESKey', 'Corpid', 'BotId', 'port']
missing_keys = [key for key in required_keys if key not in config]
if missing_keys:
raise Exception(f'WecomBot 缺少配置项: {missing_keys}')
# 创建运行时 bot 对象
bot = WecomBotClient(
Token=config['Token'],
EnCodingAESKey=config['EncodingAESKey'],
Corpid=config['Corpid'],
logger=logger,
unified_mode=True,
)
bot_account_id = config['BotId']
@@ -120,50 +117,6 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
content = await self.message_converter.yiri2target(message)
await self.bot.set_message(message_source.source_platform_object.message_id, content)
async def reply_message_chunk(
self,
message_source: platform_events.MessageEvent,
bot_message,
message: platform_message.MessageChain,
quote_origin: bool = False,
is_final: bool = False,
):
"""将流水线增量输出写入企业微信 stream 会话。
Args:
message_source: 流水线提供的原始消息事件。
bot_message: 当前片段对应的模型元信息(未使用)。
message: 需要回复的消息链。
quote_origin: 是否引用原消息(企业微信暂不支持)。
is_final: 标记当前片段是否为最终回复。
Returns:
dict: 包含 `stream` 键,标识写入是否成功。
Example:
在流水线 `reply_message_chunk` 调用中自动触发,无需手动调用。
"""
# 转换为纯文本(智能机器人当前协议仅支持文本流)
content = await self.message_converter.yiri2target(message)
msg_id = message_source.source_platform_object.message_id
# 将片段推送到 WecomBotClient 中的队列,返回值用于判断是否走降级逻辑
success = await self.bot.push_stream_chunk(msg_id, content, is_final=is_final)
if not success and is_final:
# 未命中流式队列时使用旧有 set_message 兜底
await self.bot.set_message(msg_id, content)
return {'stream': success}
async def is_stream_output_supported(self) -> bool:
"""智能机器人侧默认开启流式能力。
Returns:
bool: 恒定返回 True。
Example:
流水线执行阶段会调用此方法以确认是否启用流式。"""
return True
async def send_message(self, target_type, target_id, message):
pass
@@ -185,46 +138,18 @@ class WecomBotAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
self.bot.on_message('group')(on_message)
except Exception:
print(traceback.format_exc())
def set_bot_uuid(self, bot_uuid: str):
"""设置 bot UUID用于生成 webhook URL"""
self.bot_uuid = bot_uuid
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
# 统一 webhook 模式下,不启动独立的 Quart 应用
# 保持运行但不启动独立端口
# 打印 webhook 回调地址
if self.bot_uuid and hasattr(self.logger, 'ap'):
try:
api_port = self.logger.ap.instance_config.data['api']['port']
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
await self.logger.info(f"企业微信机器人 Webhook 回调地址:")
await self.logger.info(f" 本地地址: {webhook_url}")
await self.logger.info(f" 公网地址: {webhook_url_public}")
await self.logger.info(f"请在企业微信后台配置此回调地址")
except Exception as e:
await self.logger.warning(f"无法生成 webhook URL: {e}")
async def keep_alive():
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await keep_alive()
await self.bot.run_task(
host='0.0.0.0',
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False

View File

@@ -11,6 +11,13 @@ metadata:
icon: wecombot.png
spec:
config:
- name: port
label:
en_US: Port
zh_Hans: 监听端口
type: integer
required: true
default: 2291
- name: Corpid
label:
en_US: Corpid

View File

@@ -121,7 +121,6 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
bot: WecomCSClient = pydantic.Field(exclude=True)
message_converter: WecomMessageConverter = WecomMessageConverter()
event_converter: WecomEventConverter = WecomEventConverter()
bot_uuid: str = None
def __init__(self, config: dict, logger: abstract_platform_logger.AbstractEventLogger):
required_keys = [
@@ -140,7 +139,6 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
token=config['token'],
EncodingAESKey=config['EncodingAESKey'],
logger=logger,
unified_mode=True,
)
super().__init__(
@@ -172,10 +170,6 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
pass
def set_bot_uuid(self, bot_uuid: str):
"""设置 bot UUID用于生成 webhook URL"""
self.bot_uuid = bot_uuid
def register_listener(
self,
event_type: typing.Type[platform_events.Event],
@@ -196,41 +190,16 @@ class WecomCSAdapter(abstract_platform_adapter.AbstractMessagePlatformAdapter):
elif event_type == platform_events.GroupMessage:
pass
async def handle_unified_webhook(self, bot_uuid: str, path: str, request):
"""处理统一 webhook 请求。
Args:
bot_uuid: Bot 的 UUID
path: 子路径(如果有的话)
request: Quart Request 对象
Returns:
响应数据
"""
return await self.bot.handle_unified_webhook(request)
async def run_async(self):
# 统一 webhook 模式下,不启动独立的 Quart 应用
# 保持运行但不启动独立端口
# 打印 webhook 回调地址
if self.bot_uuid and hasattr(self.logger, 'ap'):
try:
api_port = self.logger.ap.instance_config.data['api']['port']
webhook_url = f"http://127.0.0.1:{api_port}/bots/{self.bot_uuid}"
webhook_url_public = f"http://<Your-Public-IP>:{api_port}/bots/{self.bot_uuid}"
await self.logger.info(f"企业微信客服 Webhook 回调地址:")
await self.logger.info(f" 本地地址: {webhook_url}")
await self.logger.info(f" 公网地址: {webhook_url_public}")
await self.logger.info(f"请在企业微信后台配置此回调地址")
except Exception as e:
await self.logger.warning(f"无法生成 webhook URL: {e}")
async def keep_alive():
async def shutdown_trigger_placeholder():
while True:
await asyncio.sleep(1)
await keep_alive()
await self.bot.run_task(
host='0.0.0.0',
port=self.config['port'],
shutdown_trigger=shutdown_trigger_placeholder,
)
async def kill(self) -> bool:
return False

View File

@@ -11,6 +11,13 @@ metadata:
icon: wecom.png
spec:
config:
- name: port
label:
en_US: Port
zh_Hans: 监听端口
type: int
required: true
default: 2289
- name: corpid
label:
en_US: Corpid

View File

@@ -4,7 +4,6 @@ import json
from typing import List
from pkg.rag.knowledge.services import base_service
from pkg.core import app
from langchain_text_splitters import RecursiveCharacterTextSplitter
class Chunker(base_service.BaseService):
@@ -28,6 +27,21 @@ class Chunker(base_service.BaseService):
"""
if not text:
return []
# words = text.split()
# chunks = []
# current_chunk = []
# for word in words:
# current_chunk.append(word)
# if len(current_chunk) > self.chunk_size:
# chunks.append(" ".join(current_chunk[:self.chunk_size]))
# current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
# if current_chunk:
# chunks.append(" ".join(current_chunk))
# A more robust chunking strategy (e.g., using recursive character text splitter)
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,

View File

@@ -60,7 +60,6 @@ dependencies = [
"ebooklib>=0.18",
"html2text>=2024.2.26",
"langchain>=0.2.0",
"langchain-text-splitters>=0.0.1",
"chromadb>=0.4.24",
"qdrant-client (>=1.15.1,<2.0.0)",
"langbot-plugin==0.1.4",

View File

@@ -1,4 +1,4 @@
import React, { useEffect, useState } from 'react';
import { useEffect, useState } from 'react';
import {
IChooseAdapterEntity,
IPipelineEntity,
@@ -112,87 +112,12 @@ export default function BotForm({
IDynamicFormItemSchema[]
>([]);
const [, setIsLoading] = useState<boolean>(false);
const [webhookUrl, setWebhookUrl] = useState<string>('');
const webhookInputRef = React.useRef<HTMLInputElement>(null);
useEffect(() => {
setBotFormValues();
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
// 复制到剪贴板的辅助函数 - 使用页面上的真实input元素
const copyToClipboard = () => {
console.log('[Copy] Attempting to copy from input element');
const inputElement = webhookInputRef.current;
if (!inputElement) {
console.error('[Copy] Input element not found');
toast.error(t('common.copyFailed'));
return;
}
try {
// 确保input元素可见且未被禁用
inputElement.disabled = false;
inputElement.readOnly = false;
// 聚焦并选中所有文本
inputElement.focus();
inputElement.select();
// 尝试使用现代API
if (navigator.clipboard && navigator.clipboard.writeText) {
console.log(
'[Copy] Using Clipboard API with input value:',
inputElement.value,
);
navigator.clipboard
.writeText(inputElement.value)
.then(() => {
console.log('[Copy] Clipboard API success');
inputElement.blur(); // 取消选中
inputElement.readOnly = true;
toast.success(t('bots.webhookUrlCopied'));
})
.catch((err) => {
console.error(
'[Copy] Clipboard API failed, trying execCommand:',
err,
);
// 降级到execCommand
const successful = document.execCommand('copy');
console.log('[Copy] execCommand result:', successful);
inputElement.blur();
inputElement.readOnly = true;
if (successful) {
toast.success(t('bots.webhookUrlCopied'));
} else {
toast.error(t('common.copyFailed'));
}
});
} else {
// 直接使用execCommand
console.log(
'[Copy] Using execCommand with input value:',
inputElement.value,
);
const successful = document.execCommand('copy');
console.log('[Copy] execCommand result:', successful);
inputElement.blur();
inputElement.readOnly = true;
if (successful) {
toast.success(t('bots.webhookUrlCopied'));
} else {
toast.error(t('common.copyFailed'));
}
}
} catch (err) {
console.error('[Copy] Copy failed:', err);
inputElement.readOnly = true;
toast.error(t('common.copyFailed'));
}
};
function setBotFormValues() {
initBotFormComponent().then(() => {
// 拉取初始化表单信息
@@ -208,20 +133,12 @@ export default function BotForm({
console.log('form', form.getValues());
handleAdapterSelect(val.adapter);
// dynamicForm.setFieldsValue(val.adapter_config);
// 设置 webhook 地址(如果有)
if (val.webhook_full_url) {
setWebhookUrl(val.webhook_full_url);
} else {
setWebhookUrl('');
}
})
.catch((err) => {
toast.error(t('bots.getBotConfigError') + err.message);
});
} else {
form.reset();
setWebhookUrl('');
}
});
}
@@ -296,7 +213,7 @@ export default function BotForm({
async function getBotConfig(
botId: string,
): Promise<z.infer<typeof formSchema> & { webhook_full_url?: string }> {
): Promise<z.infer<typeof formSchema>> {
return new Promise((resolve, reject) => {
httpClient
.getBot(botId)
@@ -309,10 +226,6 @@ export default function BotForm({
adapter_config: bot.adapter_config,
enable: bot.enable ?? true,
use_pipeline_uuid: bot.use_pipeline_uuid ?? '',
webhook_full_url: bot.adapter_runtime_values
? ((bot.adapter_runtime_values as Record<string, unknown>)
.webhook_full_url as string)
: undefined,
});
})
.catch((err) => {
@@ -456,86 +369,51 @@ export default function BotForm({
<div className="space-y-4">
{/* 是否启用 & 绑定流水线 仅在编辑模式 */}
{initBotId && (
<>
<div className="flex items-center gap-6">
<FormField
control={form.control}
name="enable"
render={({ field }) => (
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
<FormLabel>{t('common.enable')}</FormLabel>
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
</FormItem>
)}
/>
<div className="flex items-center gap-6">
<FormField
control={form.control}
name="enable"
render={({ field }) => (
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
<FormLabel>{t('common.enable')}</FormLabel>
<FormControl>
<Switch
checked={field.value}
onCheckedChange={field.onChange}
/>
</FormControl>
</FormItem>
)}
/>
<FormField
control={form.control}
name="use_pipeline_uuid"
render={({ field }) => (
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
<FormLabel>{t('bots.bindPipeline')}</FormLabel>
<FormControl>
<Select onValueChange={field.onChange} {...field}>
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectValue
placeholder={t('bots.selectPipeline')}
/>
</SelectTrigger>
<SelectContent className="fixed z-[1000]">
<SelectGroup>
{pipelineNameList.map((item) => (
<SelectItem
key={item.value}
value={item.value}
>
{item.label}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormControl>
</FormItem>
)}
/>
</div>
{/* Webhook 地址显示(统一 Webhook 模式) */}
{webhookUrl && (
<FormItem>
<FormLabel>{t('bots.webhookUrl')}</FormLabel>
<div className="flex items-center gap-2">
<Input
ref={webhookInputRef}
value={webhookUrl}
readOnly
className="flex-1 bg-gray-50 dark:bg-gray-900"
onClick={(e) => {
// 点击输入框时自动全选
(e.target as HTMLInputElement).select();
}}
/>
<Button
type="button"
variant="outline"
size="sm"
onClick={copyToClipboard}
>
{t('common.copy')}
</Button>
</div>
<p className="text-sm text-gray-500 mt-1">
{t('bots.webhookUrlHint')}
</p>
</FormItem>
)}
</>
<FormField
control={form.control}
name="use_pipeline_uuid"
render={({ field }) => (
<FormItem className="flex flex-col justify-start gap-[0.8rem] h-[3.8rem]">
<FormLabel>{t('bots.bindPipeline')}</FormLabel>
<FormControl>
<Select onValueChange={field.onChange} {...field}>
<SelectTrigger className="bg-[#ffffff] dark:bg-[#2a2a2e]">
<SelectValue
placeholder={t('bots.selectPipeline')}
/>
</SelectTrigger>
<SelectContent className="fixed z-[1000]">
<SelectGroup>
{pipelineNameList.map((item) => (
<SelectItem key={item.value} value={item.value}>
{item.label}
</SelectItem>
))}
</SelectGroup>
</SelectContent>
</Select>
</FormControl>
</FormItem>
)}
/>
</div>
)}
<FormField

View File

@@ -11,6 +11,10 @@ import { Message } from '@/app/infra/entities/message';
import { toast } from 'sonner';
import AtBadge from './AtBadge';
import { Switch } from '@/components/ui/switch';
import {
PipelineWebSocketClient,
SessionType,
} from '@/app/infra/websocket/PipelineWebSocketClient';
interface MessageComponent {
type: 'At' | 'Plain';
@@ -31,17 +35,27 @@ export default function DebugDialog({
}: DebugDialogProps) {
const { t } = useTranslation();
const [selectedPipelineId, setSelectedPipelineId] = useState(pipelineId);
const [sessionType, setSessionType] = useState<'person' | 'group'>('person');
const [sessionType, setSessionType] = useState<SessionType>('person');
const [messages, setMessages] = useState<Message[]>([]);
const [inputValue, setInputValue] = useState('');
const [showAtPopover, setShowAtPopover] = useState(false);
const [hasAt, setHasAt] = useState(false);
const [isHovering, setIsHovering] = useState(false);
const [isStreaming, setIsStreaming] = useState(true);
const messagesEndRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
const popoverRef = useRef<HTMLDivElement>(null);
// WebSocket states
const [wsClient, setWsClient] = useState<PipelineWebSocketClient | null>(
null,
);
const [connectionStatus, setConnectionStatus] = useState<
'disconnected' | 'connecting' | 'connected'
>('disconnected');
const [pendingMessages, setPendingMessages] = useState<
Map<string, Message>
>(new Map());
const scrollToBottom = useCallback(() => {
// 使用setTimeout确保在DOM更新后执行滚动
setTimeout(() => {
@@ -57,37 +71,161 @@ export default function DebugDialog({
}, 0);
}, []);
const loadMessages = useCallback(
async (pipelineId: string) => {
try {
const response = await httpClient.getWebChatHistoryMessages(
pipelineId,
sessionType,
);
setMessages(response.messages);
} catch (error) {
console.error('Failed to load messages:', error);
}
},
[sessionType],
);
// 在useEffect中监听messages变化时滚动
// Scroll to bottom when messages change
useEffect(() => {
scrollToBottom();
}, [messages, scrollToBottom]);
// WebSocket connection setup
useEffect(() => {
if (open) {
setSelectedPipelineId(pipelineId);
loadMessages(pipelineId);
}
}, [open, pipelineId]);
if (!open) return;
useEffect(() => {
if (open) {
loadMessages(selectedPipelineId);
}
}, [sessionType, selectedPipelineId, open, loadMessages]);
const client = new PipelineWebSocketClient(
selectedPipelineId,
sessionType,
);
// Setup event handlers
client.onConnected = (data) => {
console.log('[DebugDialog] WebSocket connected:', data);
setConnectionStatus('connected');
// Load history messages after connection
client.loadHistory();
};
client.onHistory = (data) => {
console.log('[DebugDialog] History loaded:', data?.messages.length);
if (data) {
setMessages(data.messages);
}
};
client.onMessageSent = (data) => {
console.log('[DebugDialog] Message sent confirmed:', data);
if (data) {
// Update client message ID to server message ID
const clientMsgId = data.client_message_id;
const serverMsgId = data.server_message_id;
setMessages((prev) =>
prev.map((msg) =>
msg.id === -1 && pendingMessages.has(clientMsgId)
? { ...msg, id: serverMsgId }
: msg,
),
);
setPendingMessages((prev) => {
const newMap = new Map(prev);
newMap.delete(clientMsgId);
return newMap;
});
}
};
client.onMessageStart = (data) => {
console.log('[DebugDialog] Message start:', data);
if (data) {
const placeholderMessage: Message = {
id: data.message_id,
role: 'assistant',
content: '',
message_chain: [],
timestamp: data.timestamp,
};
setMessages((prev) => [...prev, placeholderMessage]);
}
};
client.onMessageChunk = (data) => {
if (data) {
// Update streaming message (content is cumulative)
setMessages((prev) =>
prev.map((msg) =>
msg.id === data.message_id
? {
...msg,
content: data.content,
message_chain: data.message_chain,
}
: msg,
),
);
}
};
client.onMessageComplete = (data) => {
console.log('[DebugDialog] Message complete:', data);
if (data) {
// Mark message as complete
setMessages((prev) =>
prev.map((msg) =>
msg.id === data.message_id
? {
...msg,
content: data.final_content,
message_chain: data.message_chain,
}
: msg,
),
);
}
};
client.onMessageError = (data) => {
console.error('[DebugDialog] Message error:', data);
if (data) {
toast.error(`Message error: ${data.error}`);
}
};
client.onPluginMessage = (data) => {
console.log('[DebugDialog] Plugin message:', data);
if (data) {
const pluginMessage: Message = {
id: data.message_id,
role: 'assistant',
content: data.content,
message_chain: data.message_chain,
timestamp: data.timestamp,
};
setMessages((prev) => [...prev, pluginMessage]);
}
};
client.onError = (data) => {
console.error('[DebugDialog] WebSocket error:', data);
if (data) {
toast.error(`WebSocket error: ${data.error}`);
}
};
client.onDisconnected = () => {
console.log('[DebugDialog] WebSocket disconnected');
setConnectionStatus('disconnected');
};
// Connect to WebSocket
setConnectionStatus('connecting');
client
.connect(httpClient.getSessionSync())
.then(() => {
console.log('[DebugDialog] WebSocket connection established');
})
.catch((err) => {
console.error('[DebugDialog] Failed to connect WebSocket:', err);
toast.error('Failed to connect to server');
setConnectionStatus('disconnected');
});
setWsClient(client);
// Cleanup on unmount or session type change
return () => {
console.log('[DebugDialog] Cleaning up WebSocket connection');
client.disconnect();
};
}, [open, selectedPipelineId, sessionType]); // Reconnect when session type changes
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
@@ -150,6 +288,12 @@ export default function DebugDialog({
const sendMessage = async () => {
if (!inputValue.trim() && !hasAt) return;
// Check WebSocket connection
if (!wsClient || connectionStatus !== 'connected') {
toast.error('Not connected to server');
return;
}
try {
const messageChain = [];
@@ -170,7 +314,7 @@ export default function DebugDialog({
});
if (hasAt) {
// for showing
// For display
text_content = '@webchatbot' + text_content;
}
@@ -181,97 +325,26 @@ export default function DebugDialog({
timestamp: new Date().toISOString(),
message_chain: messageChain,
};
// 根据isStreaming状态决定使用哪种传输方式
if (isStreaming) {
// streaming
// 创建初始bot消息
const placeholderRandomId = Math.floor(Math.random() * 1000000);
const botMessagePlaceholder: Message = {
id: placeholderRandomId,
role: 'assistant',
content: 'Generating...',
timestamp: new Date().toISOString(),
message_chain: [{ type: 'Plain', text: 'Generating...' }],
};
// 添加用户消息和初始bot消息到状态
// Add user message to UI immediately
setMessages((prevMessages) => [...prevMessages, userMessage]);
setInputValue('');
setHasAt(false);
setMessages((prevMessages) => [
...prevMessages,
userMessage,
botMessagePlaceholder,
]);
setInputValue('');
setHasAt(false);
try {
await httpClient.sendStreamingWebChatMessage(
sessionType,
messageChain,
selectedPipelineId,
(data) => {
// 处理流式响应数据
console.log('data', data);
if (data.message) {
// 更新完整内容
// Send via WebSocket
const clientMessageId = wsClient.sendMessage(messageChain);
setMessages((prevMessages) => {
const updatedMessages = [...prevMessages];
const botMessageIndex = updatedMessages.findIndex(
(message) => message.id === placeholderRandomId,
);
if (botMessageIndex !== -1) {
updatedMessages[botMessageIndex] = {
...updatedMessages[botMessageIndex],
content: data.message.content,
message_chain: [
{ type: 'Plain', text: data.message.content },
],
};
}
return updatedMessages;
});
}
},
() => {},
(error) => {
// 处理错误
console.error('Streaming error:', error);
if (sessionType === 'person') {
toast.error(t('pipelines.debugDialog.sendFailed'));
}
},
);
} catch (error) {
console.error('Failed to send streaming message:', error);
if (sessionType === 'person') {
toast.error(t('pipelines.debugDialog.sendFailed'));
}
}
} else {
// non-streaming
setMessages((prevMessages) => [...prevMessages, userMessage]);
setInputValue('');
setHasAt(false);
// Track pending message for ID mapping
setPendingMessages((prev) => {
const newMap = new Map(prev);
newMap.set(clientMessageId, userMessage);
return newMap;
});
const response = await httpClient.sendWebChatMessage(
sessionType,
messageChain,
selectedPipelineId,
180000,
);
setMessages((prevMessages) => [...prevMessages, response.message]);
}
} catch (
// eslint-disable-next-line @typescript-eslint/no-explicit-any
error: any
) {
console.log(error, 'type of error', typeof error);
console.error('Failed to send message:', error);
if (!error.message.includes('timeout') && sessionType === 'person') {
toast.error(t('pipelines.debugDialog.sendFailed'));
}
console.log('[DebugDialog] Message sent:', clientMessageId);
} catch (error) {
console.error('[DebugDialog] Failed to send message:', error);
toast.error(t('pipelines.debugDialog.sendFailed'));
} finally {
inputRef.current?.focus();
}
@@ -390,12 +463,6 @@ export default function DebugDialog({
</ScrollArea>
<div className="p-4 pb-0 bg-white dark:bg-black flex gap-2">
<div className="flex items-center gap-2">
<span className="text-sm text-gray-600">
{t('pipelines.debugDialog.streaming')}
</span>
<Switch checked={isStreaming} onCheckedChange={setIsStreaming} />
</div>
<div className="flex-1 flex items-center gap-2">
{hasAt && (
<AtBadge targetName="webchatbot" onRemove={handleAtRemove} />

View File

@@ -141,7 +141,6 @@ export interface Bot {
use_pipeline_uuid?: string;
created_at?: string;
updated_at?: string;
adapter_runtime_values?: object;
}
export interface ApiRespKnowledgeBases {

View File

@@ -0,0 +1,394 @@
/**
* Pipeline WebSocket Client
*
* Provides real-time bidirectional communication for pipeline debugging.
* Supports person and group session isolation.
*/
import { Message } from '@/app/infra/entities/message';
export type SessionType = 'person' | 'group';
export interface WebSocketEventData {
// Connected event
connected?: {
connection_id: string;
session_type: SessionType;
pipeline_uuid: string;
};
// History event
history?: {
messages: Message[];
has_more: boolean;
};
// Message sent confirmation
message_sent?: {
client_message_id: string;
server_message_id: number;
timestamp: string;
};
// Message start
message_start?: {
message_id: number;
role: 'assistant';
timestamp: string;
reply_to: number;
};
// Message chunk
message_chunk?: {
message_id: number;
content: string;
message_chain: object[];
timestamp: string;
};
// Message complete
message_complete?: {
message_id: number;
final_content: string;
message_chain: object[];
timestamp: string;
};
// Message error
message_error?: {
message_id: number;
error: string;
error_code?: string;
};
// Interrupted
interrupted?: {
message_id: number;
partial_content: string;
};
// Plugin message
plugin_message?: {
message_id: number;
role: 'assistant';
content: string;
message_chain: object[];
timestamp: string;
source: 'plugin';
};
// Error
error?: {
error: string;
error_code: string;
details?: object;
};
// Pong
pong?: {
timestamp: number;
};
}
export class PipelineWebSocketClient {
private ws: WebSocket | null = null;
private pipelineId: string;
private sessionType: SessionType;
private reconnectAttempts = 0;
private maxReconnectAttempts = 5;
private pingInterval: NodeJS.Timeout | null = null;
private reconnectTimeout: NodeJS.Timeout | null = null;
private isManualDisconnect = false;
// Event callbacks
public onConnected?: (data: WebSocketEventData['connected']) => void;
public onHistory?: (data: WebSocketEventData['history']) => void;
public onMessageSent?: (data: WebSocketEventData['message_sent']) => void;
public onMessageStart?: (data: WebSocketEventData['message_start']) => void;
public onMessageChunk?: (data: WebSocketEventData['message_chunk']) => void;
public onMessageComplete?: (
data: WebSocketEventData['message_complete'],
) => void;
public onMessageError?: (data: WebSocketEventData['message_error']) => void;
public onInterrupted?: (data: WebSocketEventData['interrupted']) => void;
public onPluginMessage?: (data: WebSocketEventData['plugin_message']) => void;
public onError?: (data: WebSocketEventData['error']) => void;
public onDisconnected?: () => void;
constructor(pipelineId: string, sessionType: SessionType) {
this.pipelineId = pipelineId;
this.sessionType = sessionType;
}
/**
* Connect to WebSocket server
*/
connect(token: string): Promise<void> {
return new Promise((resolve, reject) => {
this.isManualDisconnect = false;
const wsUrl = this.buildWebSocketUrl();
console.log(`[WebSocket] Connecting to ${wsUrl}...`);
try {
this.ws = new WebSocket(wsUrl);
} catch (error) {
console.error('[WebSocket] Failed to create WebSocket:', error);
reject(error);
return;
}
this.ws.onopen = () => {
console.log('[WebSocket] Connection opened');
// Send connect event with session type and token
this.send('connect', {
pipeline_uuid: this.pipelineId,
session_type: this.sessionType,
token,
});
// Start ping interval
this.startPing();
// Reset reconnect attempts on successful connection
this.reconnectAttempts = 0;
resolve();
};
this.ws.onmessage = (event) => {
this.handleMessage(event);
};
this.ws.onerror = (error) => {
console.error('[WebSocket] Error:', error);
reject(error);
};
this.ws.onclose = (event) => {
console.log(
`[WebSocket] Connection closed: code=${event.code}, reason=${event.reason}`,
);
this.handleDisconnect();
};
});
}
/**
* Handle incoming WebSocket message
*/
private handleMessage(event: MessageEvent) {
try {
const message = JSON.parse(event.data);
const { type, data } = message;
console.log(`[WebSocket] Received: ${type}`, data);
switch (type) {
case 'connected':
this.onConnected?.(data);
break;
case 'history':
this.onHistory?.(data);
break;
case 'message_sent':
this.onMessageSent?.(data);
break;
case 'message_start':
this.onMessageStart?.(data);
break;
case 'message_chunk':
this.onMessageChunk?.(data);
break;
case 'message_complete':
this.onMessageComplete?.(data);
break;
case 'message_error':
this.onMessageError?.(data);
break;
case 'interrupted':
this.onInterrupted?.(data);
break;
case 'plugin_message':
this.onPluginMessage?.(data);
break;
case 'error':
this.onError?.(data);
break;
case 'pong':
// Heartbeat response, no action needed
break;
default:
console.warn(`[WebSocket] Unknown message type: ${type}`);
}
} catch (error) {
console.error('[WebSocket] Failed to parse message:', error);
}
}
/**
* Send message to server
*/
sendMessage(messageChain: object[]): string {
const clientMessageId = this.generateMessageId();
this.send('send_message', {
message_chain: messageChain,
client_message_id: clientMessageId,
});
return clientMessageId;
}
/**
* Load history messages
*/
loadHistory(beforeMessageId?: number, limit?: number) {
this.send('load_history', {
before_message_id: beforeMessageId,
limit,
});
}
/**
* Interrupt streaming message
*/
interrupt(messageId: number) {
this.send('interrupt', { message_id: messageId });
}
/**
* Send event to server
*/
private send(type: string, data: object) {
if (this.ws?.readyState === WebSocket.OPEN) {
const message = JSON.stringify({ type, data });
console.log(`[WebSocket] Sending: ${type}`, data);
this.ws.send(message);
} else {
console.warn(
`[WebSocket] Cannot send message, connection not open (state: ${this.ws?.readyState})`,
);
}
}
/**
* Start ping interval (heartbeat)
*/
private startPing() {
this.stopPing();
this.pingInterval = setInterval(() => {
this.send('ping', { timestamp: Date.now() });
}, 30000); // Ping every 30 seconds
}
/**
* Stop ping interval
*/
private stopPing() {
if (this.pingInterval) {
clearInterval(this.pingInterval);
this.pingInterval = null;
}
}
/**
* Handle disconnection
*/
private handleDisconnect() {
this.stopPing();
this.onDisconnected?.();
// Auto reconnect if not manual disconnect
if (
!this.isManualDisconnect &&
this.reconnectAttempts < this.maxReconnectAttempts
) {
const delay = Math.min(2000 * Math.pow(2, this.reconnectAttempts), 30000);
console.log(
`[WebSocket] Reconnecting in ${delay}ms (attempt ${this.reconnectAttempts + 1}/${this.maxReconnectAttempts})`,
);
this.reconnectTimeout = setTimeout(() => {
this.reconnectAttempts++;
// Note: Need to get token again, should be handled by caller
console.warn(
'[WebSocket] Auto-reconnect requires token, please reconnect manually',
);
}, delay);
} else if (this.reconnectAttempts >= this.maxReconnectAttempts) {
console.error(
'[WebSocket] Max reconnect attempts reached, giving up',
);
}
}
/**
* Disconnect from server
*/
disconnect() {
this.isManualDisconnect = true;
this.stopPing();
if (this.reconnectTimeout) {
clearTimeout(this.reconnectTimeout);
this.reconnectTimeout = null;
}
if (this.ws) {
this.ws.close(1000, 'Client disconnect');
this.ws = null;
}
console.log('[WebSocket] Disconnected');
}
/**
* Build WebSocket URL
*/
private buildWebSocketUrl(): string {
// Get current base URL
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const host = window.location.host;
return `${protocol}//${host}/api/v1/pipelines/${this.pipelineId}/chat/ws`;
}
/**
* Generate unique client message ID
*/
private generateMessageId(): string {
return `${Date.now()}-${Math.random().toString(36).substr(2, 9)}`;
}
/**
* Get connection state
*/
getState():
| 'CONNECTING'
| 'OPEN'
| 'CLOSING'
| 'CLOSED'
| 'DISCONNECTED' {
if (!this.ws) return 'DISCONNECTED';
switch (this.ws.readyState) {
case WebSocket.CONNECTING:
return 'CONNECTING';
case WebSocket.OPEN:
return 'OPEN';
case WebSocket.CLOSING:
return 'CLOSING';
case WebSocket.CLOSED:
return 'CLOSED';
default:
return 'DISCONNECTED';
}
}
/**
* Check if connected
*/
isConnected(): boolean {
return this.ws?.readyState === WebSocket.OPEN;
}
}

View File

@@ -39,9 +39,7 @@ const enUS = {
deleteSuccess: 'Deleted successfully',
deleteError: 'Delete failed: ',
addRound: 'Add Round',
copy: 'Copy',
copySuccess: 'Copy Successfully',
copyFailed: 'Copy Failed',
test: 'Test',
forgotPassword: 'Forgot Password?',
loading: 'Loading...',
@@ -151,10 +149,6 @@ const enUS = {
log: 'Log',
configuration: 'Configuration',
logs: 'Logs',
webhookUrl: 'Webhook Callback URL',
webhookUrlCopied: 'Webhook URL copied',
webhookUrlHint:
'Click the input to select all, then press Ctrl+C (Mac: Cmd+C) to copy, or click the button',
},
plugins: {
title: 'Plugins',

View File

@@ -40,9 +40,7 @@ const jaJP = {
deleteSuccess: '削除に成功しました',
deleteError: '削除に失敗しました:',
addRound: 'ラウンドを追加',
copy: 'コピー',
copySuccess: 'コピーに成功しました',
copyFailed: 'コピーに失敗しました',
test: 'テスト',
forgotPassword: 'パスワードを忘れた?',
loading: '読み込み中...',
@@ -153,10 +151,6 @@ const jaJP = {
log: 'ログ',
configuration: '設定',
logs: 'ログ',
webhookUrl: 'Webhook コールバック URL',
webhookUrlCopied: 'Webhook URL をコピーしました',
webhookUrlHint:
'入力ボックスをクリックして全選択し、Ctrl+C (Mac: Cmd+C) でコピーするか、右側のボタンをクリックしてください',
},
plugins: {
title: 'プラグイン',

View File

@@ -39,9 +39,7 @@ const zhHans = {
deleteSuccess: '删除成功',
deleteError: '删除失败:',
addRound: '添加回合',
copy: '复制',
copySuccess: '复制成功',
copyFailed: '复制失败',
test: '测试',
forgotPassword: '忘记密码?',
loading: '加载中...',
@@ -148,10 +146,6 @@ const zhHans = {
log: '日志',
configuration: '配置',
logs: '日志',
webhookUrl: 'Webhook 回调地址',
webhookUrlCopied: 'Webhook 地址已复制',
webhookUrlHint:
'点击输入框自动全选,然后按 Ctrl+C (Mac: Cmd+C) 复制,或点击右侧按钮',
},
plugins: {
title: '插件管理',

View File

@@ -39,9 +39,7 @@ const zhHant = {
deleteSuccess: '刪除成功',
deleteError: '刪除失敗:',
addRound: '新增回合',
copy: '複製',
copySuccess: '複製成功',
copyFailed: '複製失敗',
test: '測試',
forgotPassword: '忘記密碼?',
loading: '載入中...',
@@ -148,10 +146,6 @@ const zhHant = {
log: '日誌',
configuration: '設定',
logs: '日誌',
webhookUrl: 'Webhook 回調位址',
webhookUrlCopied: 'Webhook 位址已複製',
webhookUrlHint:
'點擊輸入框自動全選,然後按 Ctrl+C (Mac: Cmd+C) 複製,或點擊右側按鈕',
},
plugins: {
title: '外掛管理',