mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
Merge pull request #1571 from fdc310/streaming_feature
feat:add streaming output and pipeline stream
This commit is contained in:
47
fix.MD
Normal file
47
fix.MD
Normal file
@@ -0,0 +1,47 @@
|
||||
## 底层模型请求器
|
||||
|
||||
- pkg/provider/modelmgr/requesters/...
|
||||
|
||||
给 invoke_llm 加个 stream: bool 参数,并允许 invoke_llm 返回两种参数:原来的 llm_entities.Message(非流式)和 返回 llm_entities.MessageChunk(流式,需要新增这个实体)的 AsyncGenerator
|
||||
|
||||
## Runner
|
||||
|
||||
- pkg/provider/runners/...
|
||||
|
||||
每个runner的run方法也允许传入stream: bool。
|
||||
|
||||
现在的run方法本身就是生成器(AsyncGenerator),因为agent是有多回合的,会生成多条Message。但现在需要支持文本消息可以分段。
|
||||
|
||||
现在run方法应该返回 AsyncGenerator[ Union[ Message, AsyncGenerator[MessageChunk] ] ]。
|
||||
|
||||
对于 local agent 的实现上,调用模型invoke_llm时,传入stream,当发现模型返回的是Message时,即按照现在的写法操作Message;当返回的是 AsyncGenerator 时,需要 yield MessageChunk 给上层,同时需要注意判断工具调用。
|
||||
|
||||
## 流水线
|
||||
|
||||
- pkg/pipeline/process/handlers/chat.py
|
||||
|
||||
之前这里就已经有一个生成器写法了,用于处理 AsyncGenerator[Message],但现在需要加上一个判断,如果yield出来的是 Message 则按照现在的处理;如果yield出来的是 AsyncGenerator,那么就需要再 async for 一层;
|
||||
|
||||
因为流水线是基于责任链模式设计的,这里的生成结果只需要放入 Query 对象中,供下一层处理。
|
||||
|
||||
所以需要在 Query 对象中支持存入MessageChunk,现在只支持存 Message 到 resp_messages,这里得设计一下。
|
||||
|
||||
## 回复阶段
|
||||
|
||||
最终会在 pkg/pipeline/respback/respback.py 中检出 query 中的信息并发回,这里也要改成支持 MessagChunk 的。
|
||||
|
||||
这里应该判断适配器是否支持流式,若不支持,应该等待所有 MessageChunk 生成,拼接成 Message 再转换成 MessageChain 调用 send_message();
|
||||
|
||||
若支持,则uuid生成一个message id,使用该message id调用适配器的 reply_message_chunk 方法。
|
||||
|
||||
## 机器人适配器
|
||||
|
||||
因为机器人可能会由于用户配置项不同而表现为对流式的支持性不同,比如飞书默认不支持流式,需要用户额外配置卡片。
|
||||
|
||||
所以需要新增一个方法 `is_stream_output_supported() -> bool`,这个让每个适配器来判断并返回是否支持流式;
|
||||
|
||||
在发送时,得加两个方法 `send_message_chunk(target_type: str, target_id: str, message_id: , message: MessageChain)`
|
||||
|
||||
message_id 确定同一条消息,由调用方生成;
|
||||
|
||||
`reply_message_chunk(message_source: MessageEvent, message: MessageChain)`
|
||||
@@ -253,6 +253,43 @@ class DingTalkClient:
|
||||
await self.logger.error(f'failed to send proactive massage to group: {traceback.format_exc()}')
|
||||
raise Exception(f'failed to send proactive massage to group: {traceback.format_exc()}')
|
||||
|
||||
async def create_and_card(
|
||||
self, temp_card_id: str, incoming_message: dingtalk_stream.ChatbotMessage, quote_origin: bool = False
|
||||
):
|
||||
content_key = 'content'
|
||||
card_data = {content_key: ''}
|
||||
|
||||
card_instance = dingtalk_stream.AICardReplier(self.client, incoming_message)
|
||||
# print(card_instance)
|
||||
# 先投放卡片: https://open.dingtalk.com/document/orgapp/create-and-deliver-cards
|
||||
card_instance_id = await card_instance.async_create_and_deliver_card(
|
||||
temp_card_id,
|
||||
card_data,
|
||||
)
|
||||
return card_instance, card_instance_id
|
||||
|
||||
async def send_card_message(self, card_instance, card_instance_id: str, content: str, is_final: bool):
|
||||
content_key = 'content'
|
||||
try:
|
||||
await card_instance.async_streaming(
|
||||
card_instance_id,
|
||||
content_key=content_key,
|
||||
content_value=content,
|
||||
append=False,
|
||||
finished=is_final,
|
||||
failed=False,
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.exception(e)
|
||||
await card_instance.async_streaming(
|
||||
card_instance_id,
|
||||
content_key=content_key,
|
||||
content_value='',
|
||||
append=False,
|
||||
finished=is_final,
|
||||
failed=True,
|
||||
)
|
||||
|
||||
async def start(self):
|
||||
"""启动 WebSocket 连接,监听消息"""
|
||||
await self.client.start()
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import json
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
@@ -9,10 +11,18 @@ class WebChatDebugRouterGroup(group.RouterGroup):
|
||||
@self.route('/send', methods=['POST'])
|
||||
async def send_message(pipeline_uuid: str) -> str:
|
||||
"""Send a message to the pipeline for debugging"""
|
||||
|
||||
async def stream_generator(generator):
|
||||
yield 'data: {"type": "start"}\n\n'
|
||||
async for message in generator:
|
||||
yield f'data: {json.dumps({"message": message})}\n\n'
|
||||
yield 'data: {"type": "end"}\n\n'
|
||||
|
||||
try:
|
||||
data = await quart.request.get_json()
|
||||
session_type = data.get('session_type', 'person')
|
||||
message_chain_obj = data.get('message', [])
|
||||
is_stream = data.get('is_stream', False)
|
||||
|
||||
if not message_chain_obj:
|
||||
return self.http_status(400, -1, 'message is required')
|
||||
@@ -25,13 +35,33 @@ class WebChatDebugRouterGroup(group.RouterGroup):
|
||||
if not webchat_adapter:
|
||||
return self.http_status(404, -1, 'WebChat adapter not found')
|
||||
|
||||
result = await webchat_adapter.send_webchat_message(pipeline_uuid, session_type, message_chain_obj)
|
||||
|
||||
return self.success(
|
||||
data={
|
||||
'message': result,
|
||||
if is_stream:
|
||||
generator = webchat_adapter.send_webchat_message(
|
||||
pipeline_uuid, session_type, message_chain_obj, is_stream
|
||||
)
|
||||
# 设置正确的响应头
|
||||
headers = {
|
||||
'Content-Type': 'text/event-stream',
|
||||
'Transfer-Encoding': 'chunked',
|
||||
'Cache-Control': 'no-cache',
|
||||
'Connection': 'keep-alive'
|
||||
}
|
||||
)
|
||||
return quart.Response(stream_generator(generator), mimetype='text/event-stream',headers=headers)
|
||||
|
||||
else: # non-stream
|
||||
result = None
|
||||
async for message in webchat_adapter.send_webchat_message(
|
||||
pipeline_uuid, session_type, message_chain_obj
|
||||
):
|
||||
result = message
|
||||
if result is not None:
|
||||
return self.success(
|
||||
data={
|
||||
'message': result,
|
||||
}
|
||||
)
|
||||
else:
|
||||
return self.http_status(400, -1, 'message is required')
|
||||
|
||||
except Exception as e:
|
||||
return self.http_status(500, -1, f'Internal server error: {str(e)}')
|
||||
|
||||
@@ -101,7 +101,7 @@ class LLMModelsService:
|
||||
model=runtime_llm_model,
|
||||
messages=[llm_entities.Message(role='user', content='Hello, world!')],
|
||||
funcs=[],
|
||||
extra_args={},
|
||||
extra_args=model_data.get('extra_args', {}),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -87,7 +87,9 @@ class Query(pydantic.BaseModel):
|
||||
"""使用的函数,由前置处理器阶段设置"""
|
||||
|
||||
resp_messages: (
|
||||
typing.Optional[list[llm_entities.Message]] | typing.Optional[list[platform_message.MessageChain]]
|
||||
typing.Optional[list[llm_entities.Message]]
|
||||
| typing.Optional[list[platform_message.MessageChain]]
|
||||
| typing.Optional[list[llm_entities.MessageChunk]]
|
||||
) = []
|
||||
"""由Process阶段生成的回复消息对象列表"""
|
||||
|
||||
|
||||
@@ -67,7 +67,7 @@ class ContentFilterStage(stage.PipelineStage):
|
||||
if query.pipeline_config['safety']['content-filter']['scope'] == 'output-msg':
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
if not message.strip():
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
for filter in self.filter_chain:
|
||||
if filter_entities.EnableStage.PRE in filter.enable_stages:
|
||||
|
||||
@@ -93,12 +93,20 @@ class RuntimePipeline:
|
||||
query.message_event, platform_events.GroupMessage
|
||||
):
|
||||
result.user_notice.insert(0, platform_message.At(query.message_event.sender.id))
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
)
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
bot_message=query.resp_messages[-1],
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
is_final=[msg.is_final for msg in query.resp_messages][0]
|
||||
)
|
||||
else:
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=result.user_notice,
|
||||
quote_origin=query.pipeline_config['output']['misc']['quote-origin'],
|
||||
)
|
||||
if result.debug_notice:
|
||||
self.ap.logger.debug(result.debug_notice)
|
||||
if result.console_notice:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
import typing
|
||||
import traceback
|
||||
|
||||
@@ -22,11 +23,11 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
|
||||
"""Process"""
|
||||
# Call API
|
||||
# generator
|
||||
"""处理"""
|
||||
# 调API
|
||||
# 生成器
|
||||
|
||||
# Trigger plugin event
|
||||
# 触发插件事件
|
||||
event_class = (
|
||||
events.PersonNormalMessageReceived
|
||||
if query.launcher_type == core_entities.LauncherTypes.PERSON
|
||||
@@ -46,7 +47,6 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
if event_ctx.is_prevented_default():
|
||||
if event_ctx.event.reply is not None:
|
||||
mc = platform_message.MessageChain(event_ctx.event.reply)
|
||||
|
||||
query.resp_messages.append(mc)
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
@@ -54,10 +54,14 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.INTERRUPT, new_query=query)
|
||||
else:
|
||||
if event_ctx.event.alter is not None:
|
||||
# if isinstance(event_ctx.event, str): # Currently not considering multi-modal alter
|
||||
# if isinstance(event_ctx.event, str): # 现在暂时不考虑多模态alter
|
||||
query.user_message.content = event_ctx.event.alter
|
||||
|
||||
text_length = 0
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
try:
|
||||
for r in runner_module.preregistered_runners:
|
||||
@@ -65,22 +69,42 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
runner = r(self.ap, query.pipeline_config)
|
||||
break
|
||||
else:
|
||||
raise ValueError(f'Request runner not found: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
raise ValueError(f'未找到请求运行器: {query.pipeline_config["ai"]["runner"]["runner"]}')
|
||||
if is_stream:
|
||||
resp_message_id = uuid.uuid4()
|
||||
await query.adapter.create_message_card(str(resp_message_id), query.message_event)
|
||||
async for result in runner.run(query):
|
||||
result.resp_message_id = str(resp_message_id)
|
||||
if query.resp_messages:
|
||||
query.resp_messages.pop()
|
||||
if query.resp_message_chain:
|
||||
query.resp_message_chain.pop()
|
||||
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
query.resp_messages.append(result)
|
||||
self.ap.logger.info(f'对话({query.query_id})流式响应: {self.cut_str(result.readable_str())}')
|
||||
|
||||
self.ap.logger.info(f'Response({query.query_id}): {self.cut_str(result.readable_str())}')
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
else:
|
||||
async for result in runner.run(query):
|
||||
query.resp_messages.append(result)
|
||||
|
||||
self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')
|
||||
|
||||
if result.content is not None:
|
||||
text_length += len(result.content)
|
||||
|
||||
yield entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
query.session.using_conversation.messages.append(query.user_message)
|
||||
|
||||
query.session.using_conversation.messages.extend(query.resp_messages)
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Request failed({query.query_id}): {type(e).__name__} {str(e)}')
|
||||
self.ap.logger.error(f'对话({query.query_id})请求失败: {type(e).__name__} {str(e)}')
|
||||
traceback.print_exc()
|
||||
|
||||
hide_exception_info = query.pipeline_config['output']['misc']['hide-exception']
|
||||
|
||||
@@ -93,4 +117,4 @@ class ChatMessageHandler(handler.MessageHandler):
|
||||
)
|
||||
finally:
|
||||
# TODO statistics
|
||||
pass
|
||||
pass
|
||||
@@ -7,6 +7,10 @@ import asyncio
|
||||
from ...platform.types import events as platform_events
|
||||
from ...platform.types import message as platform_message
|
||||
|
||||
from ...provider import entities as llm_entities
|
||||
|
||||
|
||||
|
||||
from .. import stage, entities
|
||||
from ...core import entities as core_entities
|
||||
|
||||
@@ -36,10 +40,22 @@ class SendResponseBackStage(stage.PipelineStage):
|
||||
|
||||
quote_origin = query.pipeline_config['output']['misc']['quote-origin']
|
||||
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=query.resp_message_chain[-1],
|
||||
quote_origin=quote_origin,
|
||||
)
|
||||
has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
|
||||
# TODO 命令与流式的兼容性问题
|
||||
if await query.adapter.is_stream_output_supported() and has_chunks:
|
||||
is_final = [msg.is_final for msg in query.resp_messages][0]
|
||||
await query.adapter.reply_message_chunk(
|
||||
message_source=query.message_event,
|
||||
bot_message=query.resp_messages[-1],
|
||||
message=query.resp_message_chain[-1],
|
||||
quote_origin=quote_origin,
|
||||
is_final=is_final,
|
||||
)
|
||||
else:
|
||||
await query.adapter.reply_message(
|
||||
message_source=query.message_event,
|
||||
message=query.resp_message_chain[-1],
|
||||
quote_origin=quote_origin,
|
||||
)
|
||||
|
||||
return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
|
||||
|
||||
@@ -61,14 +61,40 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message: dict,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
"""回复消息(流式输出)
|
||||
Args:
|
||||
message_source (platform.types.MessageEvent): 消息源事件
|
||||
message_id (int): 消息ID
|
||||
message (platform.types.MessageChain): 消息链
|
||||
quote_origin (bool, optional): 是否引用原消息. Defaults to False.
|
||||
is_final (bool, optional): 流式是否结束. Defaults to False.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def create_message_card(self, message_id: typing.Type[str, int], event: platform_events.MessageEvent) -> bool:
|
||||
"""创建卡片消息
|
||||
Args:
|
||||
message_id (str): 消息ID
|
||||
event (platform_events.MessageEvent): 消息源事件
|
||||
"""
|
||||
return False
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
"""获取账号是否在指定群被禁言"""
|
||||
raise NotImplementedError
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
"""注册事件监听器
|
||||
|
||||
@@ -80,8 +106,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
|
||||
def unregister_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_message.Event],
|
||||
callback: typing.Callable[[platform_message.Event, MessagePlatformAdapter], None],
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
callback: typing.Callable[[platform_events.Event, MessagePlatformAdapter], None],
|
||||
):
|
||||
"""注销事件监听器
|
||||
|
||||
@@ -95,6 +121,10 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
|
||||
"""异步运行"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
"""是否支持流式输出"""
|
||||
return False
|
||||
|
||||
async def kill(self) -> bool:
|
||||
"""关闭适配器
|
||||
|
||||
@@ -136,7 +166,7 @@ class EventConverter:
|
||||
"""事件转换器基类"""
|
||||
|
||||
@staticmethod
|
||||
def yiri2target(event: typing.Type[platform_message.Event]):
|
||||
def yiri2target(event: typing.Type[platform_events.Event]):
|
||||
"""将源平台事件转换为目标平台事件
|
||||
|
||||
Args:
|
||||
@@ -148,7 +178,7 @@ class EventConverter:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def target2yiri(event: typing.Any) -> platform_message.Event:
|
||||
def target2yiri(event: typing.Any) -> platform_events.Event:
|
||||
"""将目标平台事件的调用参数转换为源平台的事件参数对象
|
||||
|
||||
Args:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from re import S
|
||||
import traceback
|
||||
import typing
|
||||
from libs.dingtalk_api.dingtalkevent import DingTalkEvent
|
||||
@@ -99,11 +100,15 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
message_converter: DingTalkMessageConverter = DingTalkMessageConverter()
|
||||
event_converter: DingTalkEventConverter = DingTalkEventConverter()
|
||||
config: dict
|
||||
card_instance_id_dict: dict # 回复卡片消息字典,key为消息id,value为回复卡片实例id,用于在流式消息时判断是否发送到指定卡片
|
||||
seq: int # 消息顺序,直接以seq作为标识
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.card_instance_id_dict = {}
|
||||
# self.seq = 1
|
||||
required_keys = [
|
||||
'client_id',
|
||||
'client_secret',
|
||||
@@ -139,6 +144,34 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
content, at = await DingTalkMessageConverter.yiri2target(message)
|
||||
await self.bot.send_message(content, incoming_message, at)
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
# event = await DingTalkEventConverter.yiri2target(
|
||||
# message_source,
|
||||
# )
|
||||
# incoming_message = event.incoming_message
|
||||
|
||||
# msg_id = incoming_message.message_id
|
||||
message_id = bot_message.resp_message_id
|
||||
msg_seq = bot_message.msg_sequence
|
||||
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
|
||||
content, at = await DingTalkMessageConverter.yiri2target(message)
|
||||
|
||||
card_instance, card_instance_id = self.card_instance_id_dict[message_id]
|
||||
# print(card_instance_id)
|
||||
await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.card_instance_id_dict.pop(message_id) # 消息回复结束之后删除卡片实例id
|
||||
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
content = await DingTalkMessageConverter.yiri2target(message)
|
||||
if target_type == 'person':
|
||||
@@ -146,6 +179,20 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
|
||||
if target_type == 'group':
|
||||
await self.bot.send_proactive_message_to_group(target_id, content)
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
is_stream = False
|
||||
if self.config.get('enable-stream-reply', None):
|
||||
is_stream = True
|
||||
return is_stream
|
||||
|
||||
async def create_message_card(self, message_id, event):
|
||||
card_template_id = self.config['card_template_id']
|
||||
incoming_message = event.source_platform_object.incoming_message
|
||||
# message_id = incoming_message.message_id
|
||||
card_instance, card_instance_id = await self.bot.create_and_card(card_template_id, incoming_message)
|
||||
self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
|
||||
return True
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
|
||||
@@ -46,6 +46,23 @@ spec:
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
- name: enable-stream-reply
|
||||
label:
|
||||
en_US: Enable Stream Reply Mode
|
||||
zh_Hans: 启用钉钉卡片流式回复模式
|
||||
description:
|
||||
en_US: If enabled, the bot will use the stream of lark reply mode
|
||||
zh_Hans: 如果启用,将使用钉钉卡片流式方式来回复内容
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: card_template_id
|
||||
label:
|
||||
en_US: card template id
|
||||
zh_Hans: 卡片模板ID
|
||||
type: string
|
||||
required: true
|
||||
default: "填写你的卡片template_id"
|
||||
execution:
|
||||
python:
|
||||
path: ./dingtalk.py
|
||||
|
||||
@@ -8,7 +8,6 @@ import base64
|
||||
import uuid
|
||||
import os
|
||||
import datetime
|
||||
import io
|
||||
import asyncio
|
||||
from enum import Enum
|
||||
|
||||
|
||||
@@ -9,7 +9,6 @@ import re
|
||||
import base64
|
||||
import uuid
|
||||
import json
|
||||
import time
|
||||
import datetime
|
||||
import hashlib
|
||||
from Crypto.Cipher import AES
|
||||
@@ -18,6 +17,7 @@ import aiohttp
|
||||
import lark_oapi.ws.exception
|
||||
import quart
|
||||
from lark_oapi.api.im.v1 import *
|
||||
from lark_oapi.api.cardkit.v1 import *
|
||||
|
||||
from .. import adapter
|
||||
from ...core import app
|
||||
@@ -343,8 +343,11 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
config: dict
|
||||
quart_app: quart.Quart
|
||||
ap: app.Application
|
||||
|
||||
message_id_to_card_id: typing.Dict[str, typing.Tuple[str, int]]
|
||||
|
||||
|
||||
card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片
|
||||
|
||||
seq: int # 用于在发送卡片消息中识别消息顺序,直接以seq作为标识
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.config = config
|
||||
@@ -352,7 +355,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
self.logger = logger
|
||||
self.quart_app = quart.Quart(__name__)
|
||||
self.listeners = {}
|
||||
self.message_id_to_card_id = {}
|
||||
self.card_id_dict = {}
|
||||
self.seq = 1
|
||||
|
||||
|
||||
@self.quart_app.route('/lark/callback', methods=['POST'])
|
||||
async def lark_callback():
|
||||
@@ -398,19 +403,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
return {'code': 500, 'message': 'error'}
|
||||
|
||||
async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
|
||||
if self.config['enable-card-reply'] and event.event.message.message_id not in self.message_id_to_card_id:
|
||||
self.ap.logger.debug('卡片回复模式开启')
|
||||
# 开启卡片回复模式. 这里可以实现飞书一发消息,马上创建卡片进行回复"思考中..."
|
||||
reply_message_id = await self.create_message_card(event.event.message.message_id)
|
||||
self.message_id_to_card_id[event.event.message.message_id] = (reply_message_id, time.time())
|
||||
|
||||
if len(self.message_id_to_card_id) > CARD_ID_CACHE_SIZE:
|
||||
self.message_id_to_card_id = {
|
||||
k: v
|
||||
for k, v in self.message_id_to_card_id.items()
|
||||
if v[1] > time.time() - CARD_ID_CACHE_MAX_LIFETIME
|
||||
}
|
||||
|
||||
lb_event = await self.event_converter.target2yiri(event, self.api_client)
|
||||
|
||||
await self.listeners[type(lb_event)](lb_event, self)
|
||||
@@ -430,21 +422,200 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
|
||||
pass
|
||||
|
||||
async def create_message_card(self, message_id: str) -> str:
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
is_stream = False
|
||||
if self.config.get('enable-stream-reply', None):
|
||||
is_stream = True
|
||||
return is_stream
|
||||
|
||||
async def create_card_id(self, message_id):
|
||||
try:
|
||||
self.ap.logger.debug('飞书支持stream输出,创建卡片......')
|
||||
|
||||
card_data = {"schema": "2.0", "config": {"update_multi": True, "streaming_mode": True,
|
||||
"streaming_config": {"print_step": {"default": 1},
|
||||
"print_frequency_ms": {"default": 70},
|
||||
"print_strategy": "fast"}},
|
||||
"body": {"direction": "vertical", "padding": "12px 12px 12px 12px", "elements": [{"tag": "div",
|
||||
"text": {
|
||||
"tag": "plain_text",
|
||||
"content": "LangBot",
|
||||
"text_size": "normal",
|
||||
"text_align": "left",
|
||||
"text_color": "default"},
|
||||
"icon": {
|
||||
"tag": "custom_icon",
|
||||
"img_key": "img_v3_02p3_05c65d5d-9bad-440a-a2fb-c89571bfd5bg"}},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px",
|
||||
"element_id": "streaming_txt"},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "column_set",
|
||||
"horizontal_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"columns": [
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "weighted",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "",
|
||||
"text_align": "left",
|
||||
"text_size": "normal",
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"direction": "vertical",
|
||||
"horizontal_spacing": "8px",
|
||||
"vertical_spacing": "2px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px",
|
||||
"weight": 1}],
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{"tag": "hr",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "column_set",
|
||||
"horizontal_spacing": "12px",
|
||||
"horizontal_align": "right",
|
||||
"columns": [
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "weighted",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "markdown",
|
||||
"content": "<font color=\"grey-600\">以上内容由 AI 生成,仅供参考。更多详细、准确信息可点击引用链接查看</font>",
|
||||
"text_align": "left",
|
||||
"text_size": "notation",
|
||||
"margin": "4px 0px 0px 0px",
|
||||
"icon": {
|
||||
"tag": "standard_icon",
|
||||
"token": "robot_outlined",
|
||||
"color": "grey"}}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"direction": "vertical",
|
||||
"horizontal_spacing": "8px",
|
||||
"vertical_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px",
|
||||
"weight": 1},
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "20px",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "button",
|
||||
"text": {
|
||||
"tag": "plain_text",
|
||||
"content": ""},
|
||||
"type": "text",
|
||||
"width": "fill",
|
||||
"size": "medium",
|
||||
"icon": {
|
||||
"tag": "standard_icon",
|
||||
"token": "thumbsup_outlined"},
|
||||
"hover_tips": {
|
||||
"tag": "plain_text",
|
||||
"content": "有帮助"},
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"direction": "vertical",
|
||||
"horizontal_spacing": "8px",
|
||||
"vertical_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px"},
|
||||
{
|
||||
"tag": "column",
|
||||
"width": "30px",
|
||||
"elements": [
|
||||
{
|
||||
"tag": "button",
|
||||
"text": {
|
||||
"tag": "plain_text",
|
||||
"content": ""},
|
||||
"type": "text",
|
||||
"width": "default",
|
||||
"size": "medium",
|
||||
"icon": {
|
||||
"tag": "standard_icon",
|
||||
"token": "thumbdown_outlined"},
|
||||
"hover_tips": {
|
||||
"tag": "plain_text",
|
||||
"content": "无帮助"},
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"padding": "0px 0px 0px 0px",
|
||||
"vertical_spacing": "8px",
|
||||
"horizontal_align": "left",
|
||||
"vertical_align": "top",
|
||||
"margin": "0px 0px 0px 0px"}],
|
||||
"margin": "0px 0px 4px 0px"}]}}
|
||||
# delay / fast 创建卡片模板,delay 延迟打印,fast 实时打印,可以自定义更好看的消息模板
|
||||
|
||||
request: CreateCardRequest = (
|
||||
CreateCardRequest.builder()
|
||||
.request_body(CreateCardRequestBody.builder().type('card_json').data(json.dumps(card_data)).build())
|
||||
.build()
|
||||
)
|
||||
|
||||
# 发起请求
|
||||
response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request)
|
||||
|
||||
# 处理失败返回
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}')
|
||||
self.card_id_dict[message_id] = response.data.card_id
|
||||
|
||||
card_id = response.data.card_id
|
||||
return card_id
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}')
|
||||
|
||||
async def create_message_card(self, message_id, event) -> str:
|
||||
"""
|
||||
创建卡片消息。
|
||||
使用卡片消息是因为普通消息更新次数有限制,而大模型流式返回结果可能很多而超过限制,而飞书卡片没有这个限制
|
||||
使用卡片消息是因为普通消息更新次数有限制,而大模型流式返回结果可能很多而超过限制,而飞书卡片没有这个限制(api免费次数有限)
|
||||
"""
|
||||
# message_id = event.message_chain.message_id
|
||||
|
||||
# TODO 目前只支持卡片模板方式,且卡片变量一定是content,未来这块要做成可配置
|
||||
# 发消息马上就会回复显示初始化的content信息,即思考中
|
||||
card_id = await self.create_card_id(message_id)
|
||||
content = {
|
||||
'type': 'template',
|
||||
'data': {'template_id': self.config['card_template_id'], 'template_variable': {'content': 'Thinking...'}},
|
||||
}
|
||||
'type': 'card',
|
||||
'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}},
|
||||
} # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档
|
||||
request: ReplyMessageRequest = (
|
||||
ReplyMessageRequest.builder()
|
||||
.message_id(message_id)
|
||||
.message_id(event.message_chain.message_id)
|
||||
.request_body(
|
||||
ReplyMessageRequestBody.builder().content(json.dumps(content)).msg_type('interactive').build()
|
||||
)
|
||||
@@ -459,64 +630,13 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
raise Exception(
|
||||
f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
return response.data.message_id
|
||||
return True
|
||||
|
||||
async def reply_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
if self.config['enable-card-reply']:
|
||||
await self.reply_card_message(message_source, message, quote_origin)
|
||||
else:
|
||||
await self.reply_normal_message(message_source, message, quote_origin)
|
||||
|
||||
async def reply_card_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
"""
|
||||
回复消息变成更新卡片消息
|
||||
"""
|
||||
lark_message = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
text_message = ''
|
||||
for ele in lark_message[0]:
|
||||
if ele['tag'] == 'text':
|
||||
text_message += ele['text']
|
||||
elif ele['tag'] == 'md':
|
||||
text_message += ele['text']
|
||||
|
||||
content = {
|
||||
'type': 'template',
|
||||
'data': {'template_id': self.config['card_template_id'], 'template_variable': {'content': text_message}},
|
||||
}
|
||||
|
||||
request: PatchMessageRequest = (
|
||||
PatchMessageRequest.builder()
|
||||
.message_id(self.message_id_to_card_id[message_source.message_chain.message_id][0])
|
||||
.request_body(PatchMessageRequestBody.builder().content(json.dumps(content)).build())
|
||||
.build()
|
||||
)
|
||||
|
||||
# 发起请求
|
||||
response: PatchMessageResponse = self.api_client.im.v1.message.patch(request)
|
||||
|
||||
# 处理失败返回
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
return
|
||||
|
||||
async def reply_normal_message(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
):
|
||||
# 不再需要了,因为message_id已经被包含到message_chain中
|
||||
# lark_event = await self.event_converter.yiri2target(message_source)
|
||||
@@ -550,6 +670,64 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
"""
|
||||
回复消息变成更新卡片消息
|
||||
"""
|
||||
# self.seq += 1
|
||||
message_id = bot_message.resp_message_id
|
||||
msg_seq = bot_message.msg_sequence
|
||||
if msg_seq % 8 == 0 or is_final:
|
||||
|
||||
lark_message = await self.message_converter.yiri2target(message, self.api_client)
|
||||
|
||||
|
||||
text_message = ''
|
||||
for ele in lark_message[0]:
|
||||
if ele['tag'] == 'text':
|
||||
text_message += ele['text']
|
||||
elif ele['tag'] == 'md':
|
||||
text_message += ele['text']
|
||||
|
||||
# content = {
|
||||
# 'type': 'card_json',
|
||||
# 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}},
|
||||
# }
|
||||
|
||||
request: ContentCardElementRequest = (
|
||||
ContentCardElementRequest.builder()
|
||||
.card_id(self.card_id_dict[message_id])
|
||||
.element_id('streaming_txt')
|
||||
.request_body(
|
||||
ContentCardElementRequestBody.builder()
|
||||
# .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204")
|
||||
.content(text_message)
|
||||
.sequence(msg_seq)
|
||||
.build()
|
||||
)
|
||||
.build()
|
||||
)
|
||||
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.card_id_dict.pop(message_id) # 清理已经使用过的卡片
|
||||
# 发起请求
|
||||
response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request)
|
||||
|
||||
# 处理失败返回
|
||||
if not response.success():
|
||||
raise Exception(
|
||||
f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
|
||||
)
|
||||
return
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
@@ -600,4 +778,4 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
|
||||
# 所以要设置_auto_reconnect=False,让其不重连。
|
||||
self.bot._auto_reconnect = False
|
||||
await self.bot._disconnect()
|
||||
return False
|
||||
return False
|
||||
@@ -65,23 +65,16 @@ spec:
|
||||
type: string
|
||||
required: true
|
||||
default: ""
|
||||
- name: enable-card-reply
|
||||
- name: enable-stream-reply
|
||||
label:
|
||||
en_US: Enable Card Reply Mode
|
||||
zh_Hans: 启用飞书卡片回复模式
|
||||
en_US: Enable Stream Reply Mode
|
||||
zh_Hans: 启用飞书流式回复模式
|
||||
description:
|
||||
en_US: If enabled, the bot will use the card of lark reply mode
|
||||
zh_Hans: 如果启用,将使用飞书卡片方式来回复内容
|
||||
en_US: If enabled, the bot will use the stream of lark reply mode
|
||||
zh_Hans: 如果启用,将使用飞书流式方式来回复内容
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: card_template_id
|
||||
label:
|
||||
en_US: card template id
|
||||
zh_Hans: 卡片模板ID
|
||||
type: string
|
||||
required: true
|
||||
default: "填写你的卡片template_id"
|
||||
execution:
|
||||
python:
|
||||
path: ./lark.py
|
||||
|
||||
@@ -501,7 +501,7 @@ class OfficialAdapter(adapter_model.MessagePlatformAdapter):
|
||||
for event_handler in event_handler_mapping[event_type]:
|
||||
setattr(self.bot, event_handler, wrapper)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in qqbotpy callback: {traceback.format_exc()}")
|
||||
self.logger.error(f'Error in qqbotpy callback: {traceback.format_exc()}')
|
||||
raise e
|
||||
|
||||
def unregister_listener(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
import telegram
|
||||
import telegram.ext
|
||||
from telegram import Update
|
||||
@@ -143,6 +144,10 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
config: dict
|
||||
ap: app.Application
|
||||
|
||||
msg_stream_id: dict # 流式消息id字典,key为流式消息id,value为首次消息源id,用于在流式消息时判断编辑那条消息
|
||||
|
||||
seq: int # 消息中识别消息顺序,直接以seq作为标识
|
||||
|
||||
listeners: typing.Dict[
|
||||
typing.Type[platform_events.Event],
|
||||
typing.Callable[[platform_events.Event, adapter.MessagePlatformAdapter], None],
|
||||
@@ -152,6 +157,8 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
self.config = config
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
self.msg_stream_id = {}
|
||||
# self.seq = 1
|
||||
|
||||
async def telegram_callback(update: Update, context: ContextTypes.DEFAULT_TYPE):
|
||||
if update.message.from_user.is_bot:
|
||||
@@ -160,6 +167,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
try:
|
||||
lb_event = await self.event_converter.target2yiri(update, self.bot, self.bot_account_id)
|
||||
await self.listeners[type(lb_event)](lb_event, self)
|
||||
await self.is_stream_output_supported()
|
||||
except Exception:
|
||||
await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
|
||||
|
||||
@@ -200,6 +208,70 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
|
||||
await self.bot.send_message(**args)
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
):
|
||||
msg_seq = bot_message.msg_sequence
|
||||
if (msg_seq - 1) % 8 == 0 or is_final:
|
||||
assert isinstance(message_source.source_platform_object, Update)
|
||||
components = await TelegramMessageConverter.yiri2target(message, self.bot)
|
||||
args = {}
|
||||
message_id = message_source.source_platform_object.message.id
|
||||
if quote_origin:
|
||||
args['reply_to_message_id'] = message_source.source_platform_object.message.id
|
||||
|
||||
component = components[0]
|
||||
if message_id not in self.msg_stream_id: # 当消息回复第一次时,发送新消息
|
||||
# time.sleep(0.6)
|
||||
if component['type'] == 'text':
|
||||
if self.config['markdown_card'] is True:
|
||||
content = telegramify_markdown.markdownify(
|
||||
content=component['text'],
|
||||
)
|
||||
else:
|
||||
content = component['text']
|
||||
args = {
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': content,
|
||||
}
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
send_msg = await self.bot.send_message(**args)
|
||||
send_msg_id = send_msg.message_id
|
||||
self.msg_stream_id[message_id] = send_msg_id
|
||||
else: # 存在消息的时候直接编辑消息1
|
||||
if component['type'] == 'text':
|
||||
if self.config['markdown_card'] is True:
|
||||
content = telegramify_markdown.markdownify(
|
||||
content=component['text'],
|
||||
)
|
||||
else:
|
||||
content = component['text']
|
||||
args = {
|
||||
'message_id': self.msg_stream_id[message_id],
|
||||
'chat_id': message_source.source_platform_object.effective_chat.id,
|
||||
'text': content,
|
||||
}
|
||||
if self.config['markdown_card'] is True:
|
||||
args['parse_mode'] = 'MarkdownV2'
|
||||
|
||||
await self.bot.edit_message_text(**args)
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
# self.seq = 1 # 消息回复结束之后重置seq
|
||||
self.msg_stream_id.pop(message_id) # 消息回复结束之后删除流式消息id
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
is_stream = False
|
||||
if self.config.get('enable-stream-reply', None):
|
||||
is_stream = True
|
||||
return is_stream
|
||||
|
||||
async def is_muted(self, group_id: int) -> bool:
|
||||
return False
|
||||
|
||||
@@ -222,8 +294,12 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
|
||||
self.bot_account_id = (await self.bot.get_me()).username
|
||||
await self.application.updater.start_polling(allowed_updates=Update.ALL_TYPES)
|
||||
await self.application.start()
|
||||
await self.logger.info('Telegram adapter running')
|
||||
|
||||
async def kill(self) -> bool:
|
||||
if self.application.running:
|
||||
await self.application.stop()
|
||||
if self.application.updater:
|
||||
await self.application.updater.stop()
|
||||
await self.logger.info('Telegram adapter stopped')
|
||||
return True
|
||||
|
||||
@@ -25,6 +25,16 @@ spec:
|
||||
type: boolean
|
||||
required: false
|
||||
default: true
|
||||
- name: enable-stream-reply
|
||||
label:
|
||||
en_US: Enable Stream Reply Mode
|
||||
zh_Hans: 启用电报流式回复模式
|
||||
description:
|
||||
en_US: If enabled, the bot will use the stream of telegram reply mode
|
||||
zh_Hans: 如果启用,将使用电报流式方式来回复内容
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
execution:
|
||||
python:
|
||||
path: ./telegram.py
|
||||
|
||||
@@ -19,17 +19,20 @@ class WebChatMessage(BaseModel):
|
||||
content: str
|
||||
message_chain: list[dict]
|
||||
timestamp: str
|
||||
is_final: bool = False
|
||||
|
||||
|
||||
class WebChatSession:
|
||||
id: str
|
||||
message_lists: dict[str, list[WebChatMessage]] = {}
|
||||
resp_waiters: dict[int, asyncio.Future[WebChatMessage]]
|
||||
resp_queues: dict[int, asyncio.Queue[WebChatMessage]]
|
||||
|
||||
def __init__(self, id: str):
|
||||
self.id = id
|
||||
self.message_lists = {}
|
||||
self.resp_waiters = {}
|
||||
self.resp_queues = {}
|
||||
|
||||
def get_message_list(self, pipeline_uuid: str) -> list[WebChatMessage]:
|
||||
if pipeline_uuid not in self.message_lists:
|
||||
@@ -49,6 +52,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
typing.Callable[[platform_events.Event, msadapter.MessagePlatformAdapter], None],
|
||||
] = {}
|
||||
|
||||
is_stream: bool
|
||||
|
||||
def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
|
||||
self.ap = ap
|
||||
self.logger = logger
|
||||
@@ -59,6 +64,8 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
|
||||
self.bot_account_id = 'webchatbot'
|
||||
|
||||
self.is_stream = False
|
||||
|
||||
async def send_message(
|
||||
self,
|
||||
target_type: str,
|
||||
@@ -102,12 +109,53 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
|
||||
# notify waiter
|
||||
if isinstance(message_source, platform_events.FriendMessage):
|
||||
self.webchat_person_session.resp_waiters[message_source.message_chain.message_id].set_result(message_data)
|
||||
await self.webchat_person_session.resp_queues[message_source.message_chain.message_id].put(message_data)
|
||||
elif isinstance(message_source, platform_events.GroupMessage):
|
||||
self.webchat_group_session.resp_waiters[message_source.message_chain.message_id].set_result(message_data)
|
||||
await self.webchat_group_session.resp_queues[message_source.message_chain.message_id].put(message_data)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def reply_message_chunk(
|
||||
self,
|
||||
message_source: platform_events.MessageEvent,
|
||||
bot_message,
|
||||
message: platform_message.MessageChain,
|
||||
quote_origin: bool = False,
|
||||
is_final: bool = False,
|
||||
) -> dict:
|
||||
"""回复消息"""
|
||||
message_data = WebChatMessage(
|
||||
id=-1,
|
||||
role='assistant',
|
||||
content=str(message),
|
||||
message_chain=[component.__dict__ for component in message],
|
||||
timestamp=datetime.now().isoformat(),
|
||||
)
|
||||
|
||||
# notify waiter
|
||||
session = (
|
||||
self.webchat_group_session
|
||||
if isinstance(message_source, platform_events.GroupMessage)
|
||||
else self.webchat_person_session
|
||||
)
|
||||
if message_source.message_chain.message_id not in session.resp_waiters:
|
||||
# session.resp_waiters[message_source.message_chain.message_id] = asyncio.Queue()
|
||||
queue = session.resp_queues[message_source.message_chain.message_id]
|
||||
|
||||
# if isinstance(message_source, platform_events.FriendMessage):
|
||||
# queue = self.webchat_person_session.resp_queues[message_source.message_chain.message_id]
|
||||
# elif isinstance(message_source, platform_events.GroupMessage):
|
||||
# queue = self.webchat_group_session.resp_queues[message_source.message_chain.message_id]
|
||||
if is_final and bot_message.tool_calls is None:
|
||||
message_data.is_final = True
|
||||
# print(message_data)
|
||||
await queue.put(message_data)
|
||||
|
||||
return message_data.model_dump()
|
||||
|
||||
async def is_stream_output_supported(self) -> bool:
|
||||
return self.is_stream
|
||||
|
||||
def register_listener(
|
||||
self,
|
||||
event_type: typing.Type[platform_events.Event],
|
||||
@@ -140,8 +188,13 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
await self.logger.info('WebChat调试适配器正在停止')
|
||||
|
||||
async def send_webchat_message(
|
||||
self, pipeline_uuid: str, session_type: str, message_chain_obj: typing.List[dict]
|
||||
self,
|
||||
pipeline_uuid: str,
|
||||
session_type: str,
|
||||
message_chain_obj: typing.List[dict],
|
||||
is_stream: bool = False,
|
||||
) -> dict:
|
||||
self.is_stream = is_stream
|
||||
"""发送调试消息到流水线"""
|
||||
if session_type == 'person':
|
||||
use_session = self.webchat_person_session
|
||||
@@ -152,6 +205,9 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
|
||||
message_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
use_session.resp_queues[message_id] = asyncio.Queue()
|
||||
logger.debug(f'Initialized queue for message_id: {message_id}')
|
||||
|
||||
use_session.get_message_list(pipeline_uuid).append(
|
||||
WebChatMessage(
|
||||
id=message_id,
|
||||
@@ -185,21 +241,46 @@ class WebChatAdapter(msadapter.MessagePlatformAdapter):
|
||||
|
||||
self.ap.platform_mgr.webchat_proxy_bot.bot_entity.use_pipeline_uuid = pipeline_uuid
|
||||
|
||||
# trigger pipeline
|
||||
if event.__class__ in self.listeners:
|
||||
await self.listeners[event.__class__](event, self)
|
||||
|
||||
# set waiter
|
||||
waiter = asyncio.Future[WebChatMessage]()
|
||||
use_session.resp_waiters[message_id] = waiter
|
||||
waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id))
|
||||
if is_stream:
|
||||
queue = use_session.resp_queues[message_id]
|
||||
msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
while True:
|
||||
resp_message = await queue.get()
|
||||
resp_message.id = msg_id
|
||||
if resp_message.is_final:
|
||||
resp_message.id = msg_id
|
||||
use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
yield resp_message.model_dump()
|
||||
break
|
||||
yield resp_message.model_dump()
|
||||
use_session.resp_queues.pop(message_id)
|
||||
|
||||
resp_message = await waiter
|
||||
else: # non-stream
|
||||
# set waiter
|
||||
# waiter = asyncio.Future[WebChatMessage]()
|
||||
# use_session.resp_waiters[message_id] = waiter
|
||||
# # waiter.add_done_callback(lambda future: use_session.resp_waiters.pop(message_id))
|
||||
#
|
||||
# resp_message = await waiter
|
||||
#
|
||||
# resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
#
|
||||
# use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
#
|
||||
# yield resp_message.model_dump()
|
||||
msg_id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
|
||||
resp_message.id = len(use_session.get_message_list(pipeline_uuid)) + 1
|
||||
queue = use_session.resp_queues[message_id]
|
||||
resp_message = await queue.get()
|
||||
use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
resp_message.id = msg_id
|
||||
resp_message.is_final = True
|
||||
|
||||
use_session.get_message_list(pipeline_uuid).append(resp_message)
|
||||
|
||||
return resp_message.model_dump()
|
||||
yield resp_message.model_dump()
|
||||
|
||||
def get_webchat_messages(self, pipeline_uuid: str, session_type: str) -> list[dict]:
|
||||
"""获取调试消息历史"""
|
||||
|
||||
@@ -9,7 +9,8 @@ metadata:
|
||||
en_US: "WebChat adapter for pipeline debugging"
|
||||
zh_Hans: "用于流水线调试的网页聊天适配器"
|
||||
icon: ""
|
||||
spec: {}
|
||||
spec:
|
||||
config: []
|
||||
execution:
|
||||
python:
|
||||
path: "webchat.py"
|
||||
|
||||
@@ -241,8 +241,8 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
# self.logger.info("_handler_compound_quote", ET.tostring(xml_data, encoding='unicode'))
|
||||
appmsg_data = xml_data.find('.//appmsg')
|
||||
quote_data = '' # 引用原文
|
||||
quote_id = None # 引用消息的原发送者
|
||||
tousername = None # 接收方: 所属微信的wxid
|
||||
# quote_id = None # 引用消息的原发送者
|
||||
# tousername = None # 接收方: 所属微信的wxid
|
||||
user_data = '' # 用户消息
|
||||
sender_id = xml_data.findtext('.//fromusername') # 发送方:单聊用户/群member
|
||||
|
||||
@@ -250,13 +250,10 @@ class WeChatPadMessageConverter(adapter.MessageConverter):
|
||||
if appmsg_data:
|
||||
user_data = appmsg_data.findtext('.//title') or ''
|
||||
quote_data = appmsg_data.find('.//refermsg').findtext('.//content')
|
||||
quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr')
|
||||
# quote_id = appmsg_data.find('.//refermsg').findtext('.//chatusr')
|
||||
message_list.append(platform_message.WeChatAppMsg(app_msg=ET.tostring(appmsg_data, encoding='unicode')))
|
||||
if message:
|
||||
tousername = message['to_user_name']['str']
|
||||
|
||||
_ = quote_id
|
||||
_ = tousername
|
||||
# if message:
|
||||
# tousername = message['to_user_name']['str']
|
||||
|
||||
if quote_data:
|
||||
quote_data_message_list = platform_message.MessageChain()
|
||||
|
||||
@@ -812,12 +812,14 @@ class File(MessageComponent):
|
||||
def __str__(self):
|
||||
return f'[文件]{self.name}'
|
||||
|
||||
|
||||
class Face(MessageComponent):
|
||||
"""系统表情
|
||||
此处将超级表情骰子/划拳,一同归类于face
|
||||
当face_type为rps(划拳)时 face_id 对应的是手势
|
||||
当face_type为dice(骰子)时 face_id 对应的是点数
|
||||
"""
|
||||
|
||||
type: str = 'Face'
|
||||
"""表情类型"""
|
||||
face_type: str = 'face'
|
||||
@@ -834,15 +836,15 @@ class Face(MessageComponent):
|
||||
elif self.face_type == 'rps':
|
||||
return f'[表情]{self.face_name}({self.rps_data(self.face_id)})'
|
||||
|
||||
|
||||
def rps_data(self,face_id):
|
||||
rps_dict ={
|
||||
1 : "布",
|
||||
2 : "剪刀",
|
||||
3 : "石头",
|
||||
def rps_data(self, face_id):
|
||||
rps_dict = {
|
||||
1: '布',
|
||||
2: '剪刀',
|
||||
3: '石头',
|
||||
}
|
||||
return rps_dict[face_id]
|
||||
|
||||
|
||||
# ================ 个人微信专用组件 ================
|
||||
|
||||
|
||||
@@ -971,5 +973,6 @@ class WeChatFile(MessageComponent):
|
||||
"""文件地址"""
|
||||
file_base64: str = ''
|
||||
"""base64"""
|
||||
|
||||
def __str__(self):
|
||||
return f'[文件]{self.file_name}'
|
||||
return f'[文件]{self.file_name}'
|
||||
|
||||
@@ -125,6 +125,95 @@ class Message(pydantic.BaseModel):
|
||||
return platform_message.MessageChain(mc)
|
||||
|
||||
|
||||
class MessageChunk(pydantic.BaseModel):
|
||||
"""消息"""
|
||||
|
||||
resp_message_id: typing.Optional[str] = None
|
||||
"""消息id"""
|
||||
|
||||
role: str # user, system, assistant, tool, command, plugin
|
||||
"""消息的角色"""
|
||||
|
||||
name: typing.Optional[str] = None
|
||||
"""名称,仅函数调用返回时设置"""
|
||||
|
||||
all_content: typing.Optional[str] = None
|
||||
"""所有内容"""
|
||||
|
||||
content: typing.Optional[list[ContentElement]] | typing.Optional[str] = None
|
||||
"""内容"""
|
||||
|
||||
tool_calls: typing.Optional[list[ToolCall]] = None
|
||||
"""工具调用"""
|
||||
|
||||
tool_call_id: typing.Optional[str] = None
|
||||
|
||||
is_final: bool = False
|
||||
"""是否是结束"""
|
||||
|
||||
msg_sequence: int = 0
|
||||
"""消息迭代次数"""
|
||||
|
||||
def readable_str(self) -> str:
|
||||
if self.content is not None:
|
||||
return str(self.role) + ': ' + str(self.get_content_platform_message_chain())
|
||||
elif self.tool_calls is not None:
|
||||
return f'调用工具: {self.tool_calls[0].id}'
|
||||
else:
|
||||
return '未知消息'
|
||||
|
||||
def get_content_platform_message_chain(self, prefix_text: str = '') -> platform_message.MessageChain | None:
|
||||
"""将内容转换为平台消息 MessageChain 对象
|
||||
|
||||
Args:
|
||||
prefix_text (str): 首个文字组件的前缀文本
|
||||
"""
|
||||
|
||||
if self.content is None:
|
||||
return None
|
||||
elif isinstance(self.content, str):
|
||||
return platform_message.MessageChain([platform_message.Plain(prefix_text + self.content)])
|
||||
elif isinstance(self.content, list):
|
||||
mc = []
|
||||
for ce in self.content:
|
||||
if ce.type == 'text':
|
||||
mc.append(platform_message.Plain(ce.text))
|
||||
elif ce.type == 'image_url':
|
||||
if ce.image_url.url.startswith('http'):
|
||||
mc.append(platform_message.Image(url=ce.image_url.url))
|
||||
else: # base64
|
||||
b64_str = ce.image_url.url
|
||||
|
||||
if b64_str.startswith('data:'):
|
||||
b64_str = b64_str.split(',')[1]
|
||||
|
||||
mc.append(platform_message.Image(base64=b64_str))
|
||||
|
||||
# 找第一个文字组件
|
||||
if prefix_text:
|
||||
for i, c in enumerate(mc):
|
||||
if isinstance(c, platform_message.Plain):
|
||||
mc[i] = platform_message.Plain(prefix_text + c.text)
|
||||
break
|
||||
else:
|
||||
mc.insert(0, platform_message.Plain(prefix_text))
|
||||
|
||||
return platform_message.MessageChain(mc)
|
||||
|
||||
|
||||
class ToolCallChunk(pydantic.BaseModel):
|
||||
"""工具调用"""
|
||||
|
||||
id: str
|
||||
"""工具调用ID"""
|
||||
|
||||
type: str
|
||||
"""工具调用类型"""
|
||||
|
||||
function: FunctionCall
|
||||
"""函数调用"""
|
||||
|
||||
|
||||
class Prompt(pydantic.BaseModel):
|
||||
"""供AI使用的Prompt"""
|
||||
|
||||
|
||||
@@ -84,6 +84,7 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
"""调用API
|
||||
|
||||
@@ -92,12 +93,36 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
messages (typing.List[llm_entities.Message]): 消息对象列表
|
||||
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
|
||||
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||
remove_think (bool, optional): 是否移思考中的消息. Defaults to False.
|
||||
|
||||
Returns:
|
||||
llm_entities.Message: 返回消息对象
|
||||
"""
|
||||
pass
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.MessageChunk:
|
||||
"""调用API
|
||||
|
||||
Args:
|
||||
model (RuntimeLLMModel): 使用的模型信息
|
||||
messages (typing.List[llm_entities.Message]): 消息对象列表
|
||||
funcs (typing.List[tools_entities.LLMFunction], optional): 使用的工具函数列表. Defaults to None.
|
||||
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||
remove_think (bool, optional): 是否移除思考中的消息. Defaults to False.
|
||||
|
||||
Returns:
|
||||
typing.AsyncGenerator[llm_entities.MessageChunk]: 返回消息对象
|
||||
"""
|
||||
pass
|
||||
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: RuntimeEmbeddingModel,
|
||||
|
||||
@@ -21,7 +21,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
client: anthropic.AsyncAnthropic
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
'base_url': 'https://api.anthropic.com/v1',
|
||||
'base_url': 'https://api.anthropic.com',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
@@ -44,6 +44,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key='',
|
||||
http_client=httpx_client,
|
||||
base_url=self.requester_cfg['base_url'],
|
||||
)
|
||||
|
||||
async def invoke_llm(
|
||||
@@ -53,6 +54,7 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
@@ -89,7 +91,8 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
{
|
||||
'type': 'tool_result',
|
||||
'tool_use_id': tool_call_id,
|
||||
'content': m.content,
|
||||
'is_error': False,
|
||||
'content': [{'type': 'text', 'text': m.content}],
|
||||
}
|
||||
],
|
||||
}
|
||||
@@ -133,6 +136,9 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
|
||||
args['messages'] = req_messages
|
||||
|
||||
if 'thinking' in args:
|
||||
args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000}
|
||||
|
||||
if funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs)
|
||||
|
||||
@@ -140,19 +146,17 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
args['tools'] = tools
|
||||
|
||||
try:
|
||||
# print(json.dumps(args, indent=4, ensure_ascii=False))
|
||||
resp = await self.client.messages.create(**args)
|
||||
|
||||
args = {
|
||||
'content': '',
|
||||
'role': resp.role,
|
||||
}
|
||||
|
||||
assert type(resp) is anthropic.types.message.Message
|
||||
|
||||
for block in resp.content:
|
||||
if block.type == 'thinking':
|
||||
args['content'] = '<think>' + block.thinking + '</think>\n' + args['content']
|
||||
if not remove_think and block.type == 'thinking':
|
||||
args['content'] = '<think>\n' + block.thinking + '\n</think>\n' + args['content']
|
||||
elif block.type == 'text':
|
||||
args['content'] += block.text
|
||||
elif block.type == 'tool_use':
|
||||
@@ -176,3 +180,191 @@ class AnthropicMessages(requester.ProviderAPIRequester):
|
||||
raise errors.RequesterError(f'模型无效: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求地址无效: {e.message}')
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
args = extra_args.copy()
|
||||
args['model'] = model.model_entity.name
|
||||
args['stream'] = True
|
||||
|
||||
# 处理消息
|
||||
|
||||
# system
|
||||
system_role_message = None
|
||||
|
||||
for i, m in enumerate(messages):
|
||||
if m.role == 'system':
|
||||
system_role_message = m
|
||||
|
||||
break
|
||||
|
||||
if system_role_message:
|
||||
messages.pop(i)
|
||||
|
||||
if isinstance(system_role_message, llm_entities.Message) and isinstance(system_role_message.content, str):
|
||||
args['system'] = system_role_message.content
|
||||
|
||||
req_messages = []
|
||||
|
||||
for m in messages:
|
||||
if m.role == 'tool':
|
||||
tool_call_id = m.tool_call_id
|
||||
|
||||
req_messages.append(
|
||||
{
|
||||
'role': 'user',
|
||||
'content': [
|
||||
{
|
||||
'type': 'tool_result',
|
||||
'tool_use_id': tool_call_id,
|
||||
'is_error': False, # 暂时直接写false
|
||||
'content': [
|
||||
{'type': 'text', 'text': m.content}
|
||||
], # 这里要是list包裹,应该是多个返回的情况?type类型好像也可以填其他的,暂时只写text
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
|
||||
continue
|
||||
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
|
||||
if isinstance(m.content, str) and m.content.strip() != '':
|
||||
msg_dict['content'] = [{'type': 'text', 'text': m.content}]
|
||||
elif isinstance(m.content, list):
|
||||
for i, ce in enumerate(m.content):
|
||||
if ce.type == 'image_base64':
|
||||
image_b64, image_format = await image.extract_b64_and_format(ce.image_base64)
|
||||
|
||||
alter_image_ele = {
|
||||
'type': 'image',
|
||||
'source': {
|
||||
'type': 'base64',
|
||||
'media_type': f'image/{image_format}',
|
||||
'data': image_b64,
|
||||
},
|
||||
}
|
||||
msg_dict['content'][i] = alter_image_ele
|
||||
if isinstance(msg_dict['content'], str) and msg_dict['content'] == '':
|
||||
msg_dict['content'] = [] # 这里不知道为什么会莫名有个空导致content为字符
|
||||
if m.tool_calls:
|
||||
for tool_call in m.tool_calls:
|
||||
msg_dict['content'].append(
|
||||
{
|
||||
'type': 'tool_use',
|
||||
'id': tool_call.id,
|
||||
'name': tool_call.function.name,
|
||||
'input': json.loads(tool_call.function.arguments),
|
||||
}
|
||||
)
|
||||
|
||||
del msg_dict['tool_calls']
|
||||
|
||||
req_messages.append(msg_dict)
|
||||
if 'thinking' in args:
|
||||
args['thinking'] = {'type': 'enabled', 'budget_tokens': 10000}
|
||||
|
||||
args['messages'] = req_messages
|
||||
|
||||
if funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_anthropic(funcs)
|
||||
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
try:
|
||||
role = 'assistant' # 默认角色
|
||||
# chunk_idx = 0
|
||||
think_started = False
|
||||
think_ended = False
|
||||
finish_reason = False
|
||||
content = ''
|
||||
tool_name = ''
|
||||
tool_id = ''
|
||||
async for chunk in await self.client.messages.create(**args):
|
||||
tool_call = {'id': None, 'function': {'name': None, 'arguments': None}, 'type': 'function'}
|
||||
if isinstance(
|
||||
chunk, anthropic.types.raw_content_block_start_event.RawContentBlockStartEvent
|
||||
): # 记录开始
|
||||
if chunk.content_block.type == 'tool_use':
|
||||
if chunk.content_block.name is not None:
|
||||
tool_name = chunk.content_block.name
|
||||
if chunk.content_block.id is not None:
|
||||
tool_id = chunk.content_block.id
|
||||
|
||||
tool_call['function']['name'] = tool_name
|
||||
tool_call['function']['arguments'] = ''
|
||||
tool_call['id'] = tool_id
|
||||
|
||||
if not remove_think:
|
||||
if chunk.content_block.type == 'thinking' and not remove_think:
|
||||
think_started = True
|
||||
elif chunk.content_block.type == 'text' and chunk.index != 0 and not remove_think:
|
||||
think_ended = True
|
||||
continue
|
||||
elif isinstance(chunk, anthropic.types.raw_content_block_delta_event.RawContentBlockDeltaEvent):
|
||||
if chunk.delta.type == 'thinking_delta':
|
||||
if think_started:
|
||||
think_started = False
|
||||
content = '<think>\n' + chunk.delta.thinking
|
||||
elif remove_think:
|
||||
continue
|
||||
else:
|
||||
content = chunk.delta.thinking
|
||||
elif chunk.delta.type == 'text_delta':
|
||||
if think_ended:
|
||||
think_ended = False
|
||||
content = '\n</think>\n' + chunk.delta.text
|
||||
else:
|
||||
content = chunk.delta.text
|
||||
elif chunk.delta.type == 'input_json_delta':
|
||||
tool_call['function']['arguments'] = chunk.delta.partial_json
|
||||
tool_call['function']['name'] = tool_name
|
||||
tool_call['id'] = tool_id
|
||||
elif isinstance(chunk, anthropic.types.raw_content_block_stop_event.RawContentBlockStopEvent):
|
||||
continue # 记录raw_content_block结束的
|
||||
|
||||
elif isinstance(chunk, anthropic.types.raw_message_delta_event.RawMessageDeltaEvent):
|
||||
if chunk.delta.stop_reason == 'end_turn':
|
||||
finish_reason = True
|
||||
elif isinstance(chunk, anthropic.types.raw_message_stop_event.RawMessageStopEvent):
|
||||
continue # 这个好像是完全结束
|
||||
else:
|
||||
# print(chunk)
|
||||
self.ap.logger.debug(f'anthropic chunk: {chunk}')
|
||||
continue
|
||||
|
||||
args = {
|
||||
'content': content,
|
||||
'role': role,
|
||||
'is_final': finish_reason,
|
||||
'tool_calls': None if tool_call['id'] is None else [tool_call],
|
||||
}
|
||||
# if chunk_idx == 0:
|
||||
# chunk_idx += 1
|
||||
# continue
|
||||
|
||||
# assert type(chunk) is anthropic.types.message.Chunk
|
||||
|
||||
yield llm_entities.MessageChunk(**args)
|
||||
|
||||
# return llm_entities.Message(**args)
|
||||
except anthropic.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'api-key 无效: {e.message}')
|
||||
except anthropic.BadRequestError as e:
|
||||
raise errors.RequesterError(str(e.message))
|
||||
except anthropic.NotFoundError as e:
|
||||
if 'model: ' in str(e):
|
||||
raise errors.RequesterError(f'模型无效: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求地址无效: {e.message}')
|
||||
|
||||
@@ -14,7 +14,7 @@ spec:
|
||||
zh_Hans: 基础 URL
|
||||
type: string
|
||||
required: true
|
||||
default: "https://api.anthropic.com/v1"
|
||||
default: "https://api.anthropic.com"
|
||||
- name: timeout
|
||||
label:
|
||||
en_US: Timeout
|
||||
|
||||
@@ -38,9 +38,18 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
) -> chat_completion.ChatCompletion:
|
||||
return await self.client.chat.completions.create(**args, extra_body=extra_body)
|
||||
|
||||
async def _req_stream(
|
||||
self,
|
||||
args: dict,
|
||||
extra_body: dict = {},
|
||||
):
|
||||
async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body):
|
||||
yield chunk
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
||||
|
||||
@@ -48,16 +57,192 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
||||
chatcmpl_message['role'] = 'assistant'
|
||||
|
||||
reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None
|
||||
# 处理思维链
|
||||
content = chatcmpl_message.get('content', '')
|
||||
reasoning_content = chatcmpl_message.get('reasoning_content', None)
|
||||
|
||||
# deepseek的reasoner模型
|
||||
if reasoning_content is not None:
|
||||
chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
|
||||
processed_content, _ = await self._process_thinking_content(
|
||||
content=content, reasoning_content=reasoning_content, remove_think=remove_think
|
||||
)
|
||||
|
||||
chatcmpl_message['content'] = processed_content
|
||||
|
||||
# 移除 reasoning_content 字段,避免传递给 Message
|
||||
if 'reasoning_content' in chatcmpl_message:
|
||||
del chatcmpl_message['reasoning_content']
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
return message
|
||||
|
||||
async def _process_thinking_content(
|
||||
self,
|
||||
content: str,
|
||||
reasoning_content: str = None,
|
||||
remove_think: bool = False,
|
||||
) -> tuple[str, str]:
|
||||
"""处理思维链内容
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
reasoning_content: reasoning_content 字段内容
|
||||
remove_think: 是否移除思维链
|
||||
|
||||
Returns:
|
||||
(处理后的内容, 提取的思维链内容)
|
||||
"""
|
||||
thinking_content = ''
|
||||
|
||||
# 1. 从 reasoning_content 提取思维链
|
||||
if reasoning_content:
|
||||
thinking_content = reasoning_content
|
||||
|
||||
# 2. 从 content 中提取 <think> 标签内容
|
||||
if content and '<think>' in content and '</think>' in content:
|
||||
import re
|
||||
|
||||
think_pattern = r'<think>(.*?)</think>'
|
||||
think_matches = re.findall(think_pattern, content, re.DOTALL)
|
||||
if think_matches:
|
||||
# 如果已有 reasoning_content,则追加
|
||||
if thinking_content:
|
||||
thinking_content += '\n' + '\n'.join(think_matches)
|
||||
else:
|
||||
thinking_content = '\n'.join(think_matches)
|
||||
# 移除 content 中的 <think> 标签
|
||||
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||
|
||||
# 3. 根据 remove_think 参数决定是否保留思维链
|
||||
if remove_think:
|
||||
return content, ''
|
||||
else:
|
||||
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
|
||||
if thinking_content:
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
async def _closure_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.MessageChunk:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
|
||||
args['messages'] = messages
|
||||
args['stream'] = True
|
||||
|
||||
# 流式处理状态
|
||||
tool_calls_map: dict[str, llm_entities.ToolCall] = {}
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
role = 'assistant' # 默认角色
|
||||
tool_id = ""
|
||||
tool_name = ''
|
||||
# accumulated_reasoning = '' # 仅用于判断何时结束思维链
|
||||
|
||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
||||
# 解析 chunk 数据
|
||||
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||
|
||||
finish_reason = getattr(choice, 'finish_reason', None)
|
||||
else:
|
||||
delta = {}
|
||||
finish_reason = None
|
||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
||||
if 'role' in delta and delta['role']:
|
||||
role = delta['role']
|
||||
|
||||
# 获取增量内容
|
||||
delta_content = delta.get('content', '')
|
||||
reasoning_content = delta.get('reasoning_content', '')
|
||||
|
||||
# 处理 reasoning_content
|
||||
if reasoning_content:
|
||||
# accumulated_reasoning += reasoning_content
|
||||
# 如果设置了 remove_think,跳过 reasoning_content
|
||||
if remove_think:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
||||
if not thinking_started:
|
||||
thinking_started = True
|
||||
delta_content = '<think>\n' + reasoning_content
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
delta_content = reasoning_content
|
||||
elif thinking_started and not thinking_ended and delta_content:
|
||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
||||
thinking_ended = True
|
||||
delta_content = '\n</think>\n' + delta_content
|
||||
|
||||
# 处理 content 中已有的 <think> 标签(如果需要移除)
|
||||
# if delta_content and remove_think and '<think>' in delta_content:
|
||||
# import re
|
||||
#
|
||||
# # 移除 <think> 标签及其内容
|
||||
# delta_content = re.sub(r'<think>.*?</think>', '', delta_content, flags=re.DOTALL)
|
||||
|
||||
# 处理工具调用增量
|
||||
# delta_tool_calls = None
|
||||
if delta.get('tool_calls'):
|
||||
for tool_call in delta['tool_calls']:
|
||||
if tool_call['id'] and tool_call['function']['name']:
|
||||
tool_id = tool_call['id']
|
||||
tool_name = tool_call['function']['name']
|
||||
else:
|
||||
tool_call['id'] = tool_id
|
||||
tool_call['function']['name'] = tool_name
|
||||
if tool_call['type'] is None:
|
||||
tool_call['type'] = 'function'
|
||||
|
||||
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
||||
chunk_idx += 1
|
||||
continue
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': delta.get('tool_calls'),
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
|
||||
yield llm_entities.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
@@ -65,6 +250,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
@@ -92,10 +278,10 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
args['messages'] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args, extra_body=extra_args)
|
||||
|
||||
resp = await self._req(args, extra_body=extra_args)
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
message = await self._make_msg(resp, remove_think)
|
||||
|
||||
return message
|
||||
|
||||
@@ -106,6 +292,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
@@ -119,13 +306,15 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
return await self._closure(
|
||||
msg = await self._closure(
|
||||
query=query,
|
||||
req_messages=req_messages,
|
||||
use_model=model,
|
||||
use_funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
return msg
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
@@ -169,6 +358,45 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.MessageChunk:
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
content = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
# 检查 content 列表中是否每个部分都是文本
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
# 将所有文本部分合并为一个字符串
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
async for item in self._closure_stream(
|
||||
query=query,
|
||||
req_messages=req_messages,
|
||||
use_model=model,
|
||||
use_funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
):
|
||||
yield item
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
if 'context_length_exceeded' in e.message:
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
except openai.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
||||
except openai.NotFoundError as e:
|
||||
|
||||
@@ -24,6 +24,7 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
@@ -49,10 +50,11 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
# 发送请求
|
||||
resp = await self._req(args, extra_body=extra_args)
|
||||
|
||||
# print(resp)
|
||||
|
||||
if resp is None:
|
||||
raise errors.RequesterError('接口返回为空,请确定模型提供商服务是否正常')
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
message = await self._make_msg(resp, remove_think)
|
||||
|
||||
return message
|
||||
|
||||
@@ -3,14 +3,16 @@ from __future__ import annotations
|
||||
|
||||
import typing
|
||||
|
||||
from . import chatcmpl
|
||||
from . import ppiochatcmpl
|
||||
from .. import requester
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
import re
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
|
||||
|
||||
class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
class GiteeAIChatCompletions(ppiochatcmpl.PPIOChatCompletions):
|
||||
"""Gitee AI ChatCompletions API 请求器"""
|
||||
|
||||
default_config: dict[str, typing.Any] = {
|
||||
@@ -18,34 +20,3 @@ class GiteeAIChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
async def _closure(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
# gitee 不支持多模态,把content都转换成纯文字
|
||||
for m in req_messages:
|
||||
if 'content' in m and isinstance(m['content'], list):
|
||||
m['content'] = ' '.join([c['text'] for c in m['content']])
|
||||
|
||||
args['messages'] = req_messages
|
||||
|
||||
resp = await self._req(args, extra_body=extra_args)
|
||||
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import typing
|
||||
|
||||
import openai
|
||||
@@ -34,9 +35,11 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
async def _req(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
args: dict,
|
||||
extra_body: dict = {},
|
||||
) -> chat_completion.ChatCompletion:
|
||||
remove_think: bool = False,
|
||||
) -> list[dict[str, typing.Any]]:
|
||||
args['stream'] = True
|
||||
|
||||
chunk = None
|
||||
@@ -47,78 +50,77 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
resp_gen: openai.AsyncStream = await self.client.chat.completions.create(**args, extra_body=extra_body)
|
||||
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
tool_id = ''
|
||||
tool_name = ''
|
||||
message_delta = {}
|
||||
async for chunk in resp_gen:
|
||||
# print(chunk)
|
||||
if not chunk or not chunk.id or not chunk.choices or not chunk.choices[0] or not chunk.choices[0].delta:
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.content is not None:
|
||||
pending_content += chunk.choices[0].delta.content
|
||||
delta = chunk.choices[0].delta.model_dump() if hasattr(chunk.choices[0], 'delta') else {}
|
||||
reasoning_content = delta.get('reasoning_content')
|
||||
# 处理 reasoning_content
|
||||
if reasoning_content:
|
||||
# accumulated_reasoning += reasoning_content
|
||||
# 如果设置了 remove_think,跳过 reasoning_content
|
||||
if remove_think:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
if chunk.choices[0].delta.tool_calls is not None:
|
||||
for tool_call in chunk.choices[0].delta.tool_calls:
|
||||
if tool_call.function.arguments is None:
|
||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
||||
if not thinking_started:
|
||||
thinking_started = True
|
||||
pending_content += '<think>\n' + reasoning_content
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
pending_content += reasoning_content
|
||||
elif thinking_started and not thinking_ended and delta.get('content'):
|
||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
||||
thinking_ended = True
|
||||
pending_content += '\n</think>\n' + delta.get('content')
|
||||
|
||||
if delta.get('content') is not None:
|
||||
pending_content += delta.get('content')
|
||||
|
||||
if delta.get('tool_calls') is not None:
|
||||
for tool_call in delta.get('tool_calls'):
|
||||
if tool_call['id'] != '':
|
||||
tool_id = tool_call['id']
|
||||
if tool_call['function']['name'] is not None:
|
||||
tool_name = tool_call['function']['name']
|
||||
if tool_call['function']['arguments'] is None:
|
||||
continue
|
||||
tool_call['id'] = tool_id
|
||||
tool_call['name'] = tool_name
|
||||
for tc in tool_calls:
|
||||
if tc.index == tool_call.index:
|
||||
tc.function.arguments += tool_call.function.arguments
|
||||
if tc['index'] == tool_call['index']:
|
||||
tc['function']['arguments'] += tool_call['function']['arguments']
|
||||
break
|
||||
else:
|
||||
tool_calls.append(tool_call)
|
||||
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
break
|
||||
message_delta['content'] = pending_content
|
||||
message_delta['role'] = 'assistant'
|
||||
|
||||
real_tool_calls = []
|
||||
|
||||
for tc in tool_calls:
|
||||
function = chat_completion_message_tool_call.Function(
|
||||
name=tc.function.name, arguments=tc.function.arguments
|
||||
)
|
||||
real_tool_calls.append(
|
||||
chat_completion_message_tool_call.ChatCompletionMessageToolCall(
|
||||
id=tc.id, function=function, type='function'
|
||||
)
|
||||
)
|
||||
|
||||
return (
|
||||
chat_completion.ChatCompletion(
|
||||
id=chunk.id,
|
||||
object='chat.completion',
|
||||
created=chunk.created,
|
||||
choices=[
|
||||
chat_completion.Choice(
|
||||
index=0,
|
||||
message=chat_completion.ChatCompletionMessage(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
tool_calls=real_tool_calls if len(real_tool_calls) > 0 else None,
|
||||
),
|
||||
finish_reason=chunk.choices[0].finish_reason
|
||||
if hasattr(chunk.choices[0], 'finish_reason') and chunk.choices[0].finish_reason is not None
|
||||
else 'stop',
|
||||
logprobs=chunk.choices[0].logprobs,
|
||||
)
|
||||
],
|
||||
model=chunk.model,
|
||||
service_tier=chunk.service_tier if hasattr(chunk, 'service_tier') else None,
|
||||
system_fingerprint=chunk.system_fingerprint if hasattr(chunk, 'system_fingerprint') else None,
|
||||
usage=chunk.usage if hasattr(chunk, 'usage') else None,
|
||||
)
|
||||
if chunk
|
||||
else None
|
||||
)
|
||||
message_delta['tool_calls'] = tool_calls if tool_calls else None
|
||||
# print(message_delta)
|
||||
return [message_delta]
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
chat_completion: list[dict[str, typing.Any]],
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.dict()
|
||||
chatcmpl_message = chat_completion[0]
|
||||
|
||||
# 确保 role 字段存在且不为 None
|
||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
||||
chatcmpl_message['role'] = 'assistant'
|
||||
|
||||
print(chatcmpl_message)
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
return message
|
||||
@@ -130,6 +132,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think:bool = False,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
@@ -157,13 +160,146 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
args['messages'] = messages
|
||||
|
||||
# 发送请求
|
||||
resp = await self._req(args, extra_body=extra_args)
|
||||
resp = await self._req(query, args, extra_body=extra_args, remove_think=remove_think)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
|
||||
return message
|
||||
|
||||
async def _req_stream(
|
||||
self,
|
||||
args: dict,
|
||||
extra_body: dict = {},
|
||||
) -> chat_completion.ChatCompletion:
|
||||
async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body):
|
||||
yield chunk
|
||||
|
||||
|
||||
async def _closure_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
|
||||
args['messages'] = messages
|
||||
args['stream'] = True
|
||||
|
||||
|
||||
# 流式处理状态
|
||||
tool_calls_map: dict[str, llm_entities.ToolCall] = {}
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
role = 'assistant' # 默认角色
|
||||
# accumulated_reasoning = '' # 仅用于判断何时结束思维链
|
||||
|
||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
||||
# 解析 chunk 数据
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||
finish_reason = getattr(choice, 'finish_reason', None)
|
||||
else:
|
||||
delta = {}
|
||||
finish_reason = None
|
||||
|
||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
||||
if 'role' in delta and delta['role']:
|
||||
role = delta['role']
|
||||
|
||||
# 获取增量内容
|
||||
delta_content = delta.get('content', '')
|
||||
reasoning_content = delta.get('reasoning_content', '')
|
||||
|
||||
# 处理 reasoning_content
|
||||
if reasoning_content:
|
||||
# accumulated_reasoning += reasoning_content
|
||||
# 如果设置了 remove_think,跳过 reasoning_content
|
||||
if remove_think:
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 第一次出现 reasoning_content,添加 <think> 开始标签
|
||||
if not thinking_started:
|
||||
thinking_started = True
|
||||
delta_content = '<think>\n' + reasoning_content
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
delta_content = reasoning_content
|
||||
elif thinking_started and not thinking_ended and delta_content:
|
||||
# reasoning_content 结束,normal content 开始,添加 </think> 结束标签
|
||||
thinking_ended = True
|
||||
delta_content = '\n</think>\n' + delta_content
|
||||
|
||||
# 处理 content 中已有的 <think> 标签(如果需要移除)
|
||||
# if delta_content and remove_think and '<think>' in delta_content:
|
||||
# import re
|
||||
#
|
||||
# # 移除 <think> 标签及其内容
|
||||
# delta_content = re.sub(r'<think>.*?</think>', '', delta_content, flags=re.DOTALL)
|
||||
|
||||
# 处理工具调用增量
|
||||
if delta.get('tool_calls'):
|
||||
for tool_call in delta['tool_calls']:
|
||||
if tool_call['id'] != '':
|
||||
tool_id = tool_call['id']
|
||||
if tool_call['function']['name'] is not None:
|
||||
tool_name = tool_call['function']['name']
|
||||
|
||||
if tool_call['type'] is None:
|
||||
tool_call['type'] = 'function'
|
||||
tool_call['id'] = tool_id
|
||||
tool_call['function']['name'] = tool_name
|
||||
tool_call['function']['arguments'] = "" if tool_call['function']['arguments'] is None else tool_call['function']['arguments']
|
||||
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content and not reasoning_content and not delta.get('tool_calls'):
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': delta.get('tool_calls'),
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
|
||||
yield llm_entities.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
# return
|
||||
|
||||
async def invoke_llm(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
@@ -171,6 +307,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
@@ -185,7 +322,7 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
try:
|
||||
return await self._closure(
|
||||
query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args
|
||||
query=query, req_messages=req_messages, use_model=model, use_funcs=funcs, extra_args=extra_args, remove_think=remove_think
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
@@ -202,3 +339,50 @@ class ModelScopeChatCompletions(requester.ProviderAPIRequester):
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
async def invoke_llm_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
model: requester.RuntimeLLMModel,
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.MessageChunk:
|
||||
req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行
|
||||
for m in messages:
|
||||
msg_dict = m.dict(exclude_none=True)
|
||||
content = msg_dict.get('content')
|
||||
if isinstance(content, list):
|
||||
# 检查 content 列表中是否每个部分都是文本
|
||||
if all(isinstance(part, dict) and part.get('type') == 'text' for part in content):
|
||||
# 将所有文本部分合并为一个字符串
|
||||
msg_dict['content'] = '\n'.join(part['text'] for part in content)
|
||||
req_messages.append(msg_dict)
|
||||
|
||||
try:
|
||||
async for item in self._closure_stream(
|
||||
query=query,
|
||||
req_messages=req_messages,
|
||||
use_model=model,
|
||||
use_funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
):
|
||||
yield item
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
if 'context_length_exceeded' in e.message:
|
||||
raise errors.RequesterError(f'上文过长,请重置会话: {e.message}')
|
||||
else:
|
||||
raise errors.RequesterError(f'请求参数错误: {e.message}')
|
||||
except openai.AuthenticationError as e:
|
||||
raise errors.RequesterError(f'无效的 api-key: {e.message}')
|
||||
except openai.NotFoundError as e:
|
||||
raise errors.RequesterError(f'请求路径错误: {e.message}')
|
||||
except openai.RateLimitError as e:
|
||||
raise errors.RequesterError(f'请求过于频繁或余额不足: {e.message}')
|
||||
except openai.APIError as e:
|
||||
raise errors.RequesterError(f'请求错误: {e.message}')
|
||||
|
||||
@@ -25,6 +25,7 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
@@ -54,6 +55,6 @@ class MoonshotChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
resp = await self._req(args, extra_body=extra_args)
|
||||
|
||||
# 处理请求结果
|
||||
message = await self._make_msg(resp)
|
||||
message = await self._make_msg(resp, remove_think)
|
||||
|
||||
return message
|
||||
|
||||
@@ -44,6 +44,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
args = extra_args.copy()
|
||||
args['model'] = use_model.model_entity.name
|
||||
@@ -110,6 +111,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
messages: typing.List[llm_entities.Message],
|
||||
funcs: typing.List[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message:
|
||||
req_messages: list = []
|
||||
for m in messages:
|
||||
@@ -126,6 +128,7 @@ class OllamaChatCompletions(requester.ProviderAPIRequester):
|
||||
use_model=model,
|
||||
use_funcs=funcs,
|
||||
extra_args=extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
|
||||
@@ -4,6 +4,12 @@ import openai
|
||||
import typing
|
||||
|
||||
from . import chatcmpl
|
||||
import openai.types.chat.chat_completion as chat_completion
|
||||
from .. import requester
|
||||
from ....core import entities as core_entities
|
||||
from ... import entities as llm_entities
|
||||
from ...tools import entities as tools_entities
|
||||
import re
|
||||
|
||||
|
||||
class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
@@ -15,3 +21,193 @@ class PPIOChatCompletions(chatcmpl.OpenAIChatCompletions):
|
||||
'base_url': 'https://api.ppinfra.com/v3/openai',
|
||||
'timeout': 120,
|
||||
}
|
||||
|
||||
is_think: bool = False
|
||||
|
||||
async def _make_msg(
|
||||
self,
|
||||
chat_completion: chat_completion.ChatCompletion,
|
||||
remove_think: bool,
|
||||
) -> llm_entities.Message:
|
||||
chatcmpl_message = chat_completion.choices[0].message.model_dump()
|
||||
# print(chatcmpl_message.keys(), chatcmpl_message.values())
|
||||
|
||||
# 确保 role 字段存在且不为 None
|
||||
if 'role' not in chatcmpl_message or chatcmpl_message['role'] is None:
|
||||
chatcmpl_message['role'] = 'assistant'
|
||||
|
||||
reasoning_content = chatcmpl_message['reasoning_content'] if 'reasoning_content' in chatcmpl_message else None
|
||||
|
||||
# deepseek的reasoner模型
|
||||
chatcmpl_message["content"] = await self._process_thinking_content(
|
||||
chatcmpl_message['content'],reasoning_content,remove_think)
|
||||
|
||||
# 移除 reasoning_content 字段,避免传递给 Message
|
||||
if 'reasoning_content' in chatcmpl_message:
|
||||
del chatcmpl_message['reasoning_content']
|
||||
|
||||
|
||||
message = llm_entities.Message(**chatcmpl_message)
|
||||
|
||||
return message
|
||||
|
||||
async def _process_thinking_content(
|
||||
self,
|
||||
content: str,
|
||||
reasoning_content: str = None,
|
||||
remove_think: bool = False,
|
||||
) -> tuple[str, str]:
|
||||
"""处理思维链内容
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
reasoning_content: reasoning_content 字段内容
|
||||
remove_think: 是否移除思维链
|
||||
|
||||
Returns:
|
||||
处理后的内容
|
||||
"""
|
||||
if remove_think:
|
||||
content = re.sub(
|
||||
r'<think>.*?</think>', '', content, flags=re.DOTALL
|
||||
)
|
||||
else:
|
||||
if reasoning_content is not None:
|
||||
content = (
|
||||
'<think>\n' + reasoning_content + '\n</think>\n' + content
|
||||
)
|
||||
return content
|
||||
|
||||
async def _make_msg_chunk(
|
||||
self,
|
||||
delta: dict[str, typing.Any],
|
||||
idx: int,
|
||||
) -> llm_entities.MessageChunk:
|
||||
# 处理流式chunk和完整响应的差异
|
||||
# print(chat_completion.choices[0])
|
||||
|
||||
# 确保 role 字段存在且不为 None
|
||||
if 'role' not in delta or delta['role'] is None:
|
||||
delta['role'] = 'assistant'
|
||||
|
||||
reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
|
||||
|
||||
delta['content'] = '' if delta['content'] is None else delta['content']
|
||||
# print(reasoning_content)
|
||||
|
||||
# deepseek的reasoner模型
|
||||
|
||||
if reasoning_content is not None:
|
||||
delta['content'] += reasoning_content
|
||||
|
||||
message = llm_entities.MessageChunk(**delta)
|
||||
|
||||
return message
|
||||
|
||||
async def _closure_stream(
|
||||
self,
|
||||
query: core_entities.Query,
|
||||
req_messages: list[dict],
|
||||
use_model: requester.RuntimeLLMModel,
|
||||
use_funcs: list[tools_entities.LLMFunction] = None,
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
remove_think: bool = False,
|
||||
) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
|
||||
self.client.api_key = use_model.token_mgr.get_token()
|
||||
|
||||
args = {}
|
||||
args['model'] = use_model.model_entity.name
|
||||
|
||||
if use_funcs:
|
||||
tools = await self.ap.tool_mgr.generate_tools_for_openai(use_funcs)
|
||||
|
||||
if tools:
|
||||
args['tools'] = tools
|
||||
|
||||
# 设置此次请求中的messages
|
||||
messages = req_messages.copy()
|
||||
|
||||
# 检查vision
|
||||
for msg in messages:
|
||||
if 'content' in msg and isinstance(msg['content'], list):
|
||||
for me in msg['content']:
|
||||
if me['type'] == 'image_base64':
|
||||
me['image_url'] = {'url': me['image_base64']}
|
||||
me['type'] = 'image_url'
|
||||
del me['image_base64']
|
||||
|
||||
args['messages'] = messages
|
||||
args['stream'] = True
|
||||
|
||||
tool_calls_map: dict[str, llm_entities.ToolCall] = {}
|
||||
chunk_idx = 0
|
||||
thinking_started = False
|
||||
thinking_ended = False
|
||||
role = 'assistant' # 默认角色
|
||||
async for chunk in self._req_stream(args, extra_body=extra_args):
|
||||
# 解析 chunk 数据
|
||||
if hasattr(chunk, 'choices') and chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
delta = choice.delta.model_dump() if hasattr(choice, 'delta') else {}
|
||||
finish_reason = getattr(choice, 'finish_reason', None)
|
||||
else:
|
||||
delta = {}
|
||||
finish_reason = None
|
||||
|
||||
# 从第一个 chunk 获取 role,后续使用这个 role
|
||||
if 'role' in delta and delta['role']:
|
||||
role = delta['role']
|
||||
|
||||
# 获取增量内容
|
||||
delta_content = delta.get('content', '')
|
||||
# reasoning_content = delta.get('reasoning_content', '')
|
||||
|
||||
if remove_think:
|
||||
if delta['content'] is not None:
|
||||
if '<think>' in delta['content'] and not thinking_started and not thinking_ended:
|
||||
thinking_started = True
|
||||
continue
|
||||
elif delta['content'] == r'</think>' and not thinking_ended:
|
||||
thinking_ended = True
|
||||
continue
|
||||
elif thinking_ended and delta['content'] == '\n\n' and thinking_started:
|
||||
thinking_started = False
|
||||
continue
|
||||
elif thinking_started and not thinking_ended:
|
||||
continue
|
||||
|
||||
|
||||
delta_tool_calls = None
|
||||
if delta.get('tool_calls'):
|
||||
for tool_call in delta['tool_calls']:
|
||||
if tool_call['id'] and tool_call['function']['name']:
|
||||
tool_id = tool_call['id']
|
||||
tool_name = tool_call['function']['name']
|
||||
|
||||
if tool_call['id'] is None:
|
||||
tool_call['id'] = tool_id
|
||||
if tool_call['function']['name'] is None:
|
||||
tool_call['function']['name'] = tool_name
|
||||
if tool_call['function']['arguments'] is None:
|
||||
tool_call['function']['arguments'] = ''
|
||||
if tool_call['type'] is None:
|
||||
tool_call['type'] = 'function'
|
||||
|
||||
# 跳过空的第一个 chunk(只有 role 没有内容)
|
||||
if chunk_idx == 0 and not delta_content and not delta.get('tool_calls'):
|
||||
chunk_idx += 1
|
||||
continue
|
||||
|
||||
# 构建 MessageChunk - 只包含增量内容
|
||||
chunk_data = {
|
||||
'role': role,
|
||||
'content': delta_content if delta_content else None,
|
||||
'tool_calls': delta.get('tool_calls'),
|
||||
'is_final': bool(finish_reason),
|
||||
}
|
||||
|
||||
# 移除 None 值
|
||||
chunk_data = {k: v for k, v in chunk_data.items() if v is not None}
|
||||
|
||||
yield llm_entities.MessageChunk(**chunk_data)
|
||||
chunk_idx += 1
|
||||
|
||||
@@ -99,8 +99,14 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
plain_text = '' # 用户输入的纯文本信息
|
||||
image_ids = [] # 用户输入的图片ID列表 (暂不支持)
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
think_start = False
|
||||
think_end = False
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
has_thoughts = True # 获取思考过程
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
if remove_think:
|
||||
has_thoughts = False
|
||||
# 发送对话请求
|
||||
response = dashscope.Application.call(
|
||||
api_key=self.api_key, # 智能体应用的API Key
|
||||
@@ -109,43 +115,109 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
stream=True, # 流式输出
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
has_thoughts=has_thoughts,
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
# }
|
||||
)
|
||||
idx_chunk = 0
|
||||
try:
|
||||
# print(await query.adapter.is_stream_output_supported())
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
if is_stream:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
idx_chunk += 1
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
stream_think = stream_output.get('thoughts', [])
|
||||
if stream_think[0].get('thought'):
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f"<think>\n{stream_think[0].get('thought')}"
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
pending_content += stream_think[0].get('thought')
|
||||
elif stream_think[0].get('thought') == "" and not think_end:
|
||||
think_end = True
|
||||
pending_content += "\n</think>\n"
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
# 是否是流式最后一个chunk
|
||||
is_final = False if stream_output.get('finish_reason', False) == 'null' else True
|
||||
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
if idx_chunk % 8 == 0 or is_final:
|
||||
yield llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=is_final,
|
||||
)
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
else:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
idx_chunk += 1
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
stream_think = stream_output.get('thoughts', [])
|
||||
if stream_think[0].get('thought'):
|
||||
if not think_start:
|
||||
think_start = True
|
||||
pending_content += f"<think>\n{stream_think[0].get('thought')}"
|
||||
else:
|
||||
# 继续输出 reasoning_content
|
||||
pending_content += stream_think[0].get('thought')
|
||||
elif stream_think[0].get('thought') == "" and not think_end:
|
||||
think_end = True
|
||||
pending_content += "\n</think>\n"
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
async def _workflow_messages(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""Dashscope 工作流对话请求"""
|
||||
@@ -171,52 +243,109 @@ class DashScopeAPIRunner(runner.RequestRunner):
|
||||
incremental_output=True, # 增量输出,使用流式输出需要开启增量输出
|
||||
session_id=query.session.using_conversation.uuid, # 会话ID用于,多轮对话
|
||||
biz_params=biz_params, # 工作流应用的自定义输入参数传递
|
||||
flow_stream_mode="message_format" # 消息模式,输出/结束节点的流式结果
|
||||
# rag_options={ # 主要用于文件交互,暂不支持
|
||||
# "session_file_ids": ["FILE_ID1"], # FILE_ID1 替换为实际的临时文件ID,逗号隔开多个
|
||||
# }
|
||||
)
|
||||
|
||||
# 处理API返回的流式输出
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
try:
|
||||
# print(await query.adapter.is_stream_output_supported())
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
idx_chunk = 0
|
||||
if is_stream:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
idx_chunk += 1
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('workflow_message') is not None:
|
||||
pending_content += stream_output.get('workflow_message').get('message').get('content')
|
||||
# if stream_output.get('text') is not None:
|
||||
# pending_content += stream_output.get('text')
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
is_final = False if stream_output.get('finish_reason', False) == 'null' else True
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
if idx_chunk % 8 == 0 or is_final:
|
||||
yield llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
else:
|
||||
for chunk in response:
|
||||
if chunk.get('status_code') != 200:
|
||||
raise DashscopeAPIError(
|
||||
f'Dashscope API 请求失败: status_code={chunk.get("status_code")} message={chunk.get("message")} request_id={chunk.get("request_id")} '
|
||||
)
|
||||
if not chunk:
|
||||
continue
|
||||
|
||||
# 获取流式传输的output
|
||||
stream_output = chunk.get('output', {})
|
||||
if stream_output.get('text') is not None:
|
||||
pending_content += stream_output.get('text')
|
||||
|
||||
is_final = False if stream_output.get('finish_reason', False) == 'null' else True
|
||||
|
||||
# 保存当前会话的session_id用于下次对话的语境
|
||||
query.session.using_conversation.uuid = stream_output.get('session_id')
|
||||
|
||||
# 获取模型传出的参考资料列表
|
||||
references_dict_list = stream_output.get('doc_references', [])
|
||||
|
||||
# 从模型传出的参考资料信息中提取用于替换的字典
|
||||
if references_dict_list is not None:
|
||||
for doc in references_dict_list:
|
||||
if doc.get('index_id') is not None:
|
||||
references_dict[doc.get('index_id')] = doc.get('doc_name')
|
||||
|
||||
# 将参考资料替换到文本中
|
||||
pending_content = self._replace_references(pending_content, references_dict)
|
||||
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=pending_content,
|
||||
)
|
||||
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行"""
|
||||
msg_seq = 0
|
||||
if self.app_type == 'agent':
|
||||
async for msg in self._agent_messages(query):
|
||||
if isinstance(msg, llm_entities.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
elif self.app_type == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
if isinstance(msg, llm_entities.MessageChunk):
|
||||
msg_seq += 1
|
||||
msg.msg_sequence = msg_seq
|
||||
yield msg
|
||||
else:
|
||||
raise DashscopeAPIError(f'不支持的 Dashscope 应用类型: {self.app_type}')
|
||||
|
||||
@@ -62,6 +62,39 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
|
||||
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
|
||||
|
||||
def _process_thinking_content(
|
||||
self,
|
||||
content: str,
|
||||
) -> tuple[str, str]:
|
||||
"""处理思维链内容
|
||||
|
||||
Args:
|
||||
content: 原始内容
|
||||
Returns:
|
||||
(处理后的内容, 提取的思维链内容)
|
||||
"""
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
thinking_content = ''
|
||||
# 从 content 中提取 <think> 标签内容
|
||||
if content and '<think>' in content and '</think>' in content:
|
||||
import re
|
||||
|
||||
think_pattern = r'<think>(.*?)</think>'
|
||||
think_matches = re.findall(think_pattern, content, re.DOTALL)
|
||||
if think_matches:
|
||||
thinking_content = '\n'.join(think_matches)
|
||||
# 移除 content 中的 <think> 标签
|
||||
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
|
||||
|
||||
# 3. 根据 remove_think 参数决定是否保留思维链
|
||||
if remove_think:
|
||||
return content, ''
|
||||
else:
|
||||
# 如果有思维链内容,将其以 <think> 格式添加到 content 开头
|
||||
if thinking_content:
|
||||
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
|
||||
return content, thinking_content
|
||||
|
||||
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
|
||||
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
|
||||
|
||||
@@ -108,13 +141,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
|
||||
mode = 'basic' # 标记是基础编排还是工作流编排
|
||||
|
||||
stream_output_pending_chunk = ''
|
||||
|
||||
batch_pending_max_size = self.pipeline_config['ai']['dify-service-api'].get(
|
||||
'output-batch-size', 0
|
||||
) # 积累一定量的消息更新消息一次
|
||||
|
||||
batch_pending_index = 0
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
inputs = {}
|
||||
|
||||
@@ -132,52 +159,28 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
):
|
||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||
|
||||
# 查询异常情况
|
||||
if chunk['event'] == 'error':
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=f"查询异常: [{chunk['code']}]. {chunk['message']}.\n请重试,如果还报错,请用 <font color='red'>**!reset**</font> 命令重置对话再尝试。",
|
||||
)
|
||||
|
||||
if chunk['event'] == 'workflow_started':
|
||||
mode = 'workflow'
|
||||
|
||||
if mode == 'workflow':
|
||||
if chunk['event'] == 'node_finished':
|
||||
if chunk['data']['node_type'] == 'answer':
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['answer'])
|
||||
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
|
||||
content=content,
|
||||
)
|
||||
elif chunk['event'] == 'message':
|
||||
stream_output_pending_chunk += chunk['answer']
|
||||
if self.pipeline_config['ai']['dify-service-api'].get('enable-streaming', False):
|
||||
# 消息数超过量就输出,从而达到streaming的效果
|
||||
batch_pending_index += 1
|
||||
if batch_pending_index >= batch_pending_max_size:
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(stream_output_pending_chunk),
|
||||
)
|
||||
batch_pending_index = 0
|
||||
elif mode == 'basic':
|
||||
if chunk['event'] == 'message':
|
||||
stream_output_pending_chunk += chunk['answer']
|
||||
if self.pipeline_config['ai']['dify-service-api'].get('enable-streaming', False):
|
||||
# 消息数超过量就输出,从而达到streaming的效果
|
||||
batch_pending_index += 1
|
||||
if batch_pending_index >= batch_pending_max_size:
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(stream_output_pending_chunk),
|
||||
)
|
||||
batch_pending_index = 0
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
elif chunk['event'] == 'message_end':
|
||||
content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(stream_output_pending_chunk),
|
||||
content=content,
|
||||
)
|
||||
stream_output_pending_chunk = ''
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
@@ -226,14 +229,15 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'agent_message':
|
||||
if chunk['event'] == 'agent_message' or chunk['event'] == 'message':
|
||||
pending_agent_message += chunk['answer']
|
||||
else:
|
||||
if pending_agent_message.strip() != '':
|
||||
pending_agent_message = pending_agent_message.replace('</details>Action:', '</details>')
|
||||
content, _ = self._process_thinking_content(pending_agent_message)
|
||||
yield llm_entities.Message(
|
||||
role='assistant',
|
||||
content=self._try_convert_thinking(pending_agent_message),
|
||||
content=content,
|
||||
)
|
||||
pending_agent_message = ''
|
||||
|
||||
@@ -341,26 +345,353 @@ class DifyServiceAPIRunner(runner.RequestRunner):
|
||||
elif chunk['event'] == 'workflow_finished':
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role='assistant',
|
||||
content=chunk['data']['outputs']['summary'],
|
||||
content=content,
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
|
||||
async def _chat_messages_chunk(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.MessageChunk, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
}
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
mode = 'basic' # 标记是基础编排还是工作流编排
|
||||
|
||||
basic_mode_pending_chunk = ''
|
||||
|
||||
inputs = {}
|
||||
|
||||
inputs.update(query.variables)
|
||||
message_idx = 0
|
||||
|
||||
chunk = None # 初始化chunk变量,防止在没有响应时引用错误
|
||||
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
|
||||
|
||||
# if chunk['event'] == 'workflow_started':
|
||||
# mode = 'workflow'
|
||||
# if mode == 'workflow':
|
||||
# elif mode == 'basic':
|
||||
# 因为都只是返回的 message也没有工具调用什么的,暂时不分类
|
||||
if chunk['event'] == 'message':
|
||||
message_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['answer'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['answer'] and not think_end:
|
||||
import re
|
||||
content = re.sub(r'^\n</think>', '', chunk['answer'])
|
||||
basic_mode_pending_chunk += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
if think_start:
|
||||
continue
|
||||
|
||||
else:
|
||||
basic_mode_pending_chunk += chunk['answer']
|
||||
|
||||
if chunk['event'] == 'message_end':
|
||||
is_final = True
|
||||
|
||||
if is_final or message_idx % 8 == 0:
|
||||
# content, _ = self._process_thinking_content(basic_mode_pending_chunk)
|
||||
yield llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=basic_mode_pending_chunk,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
|
||||
async def _agent_chat_messages_chunk(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""调用聊天助手"""
|
||||
cov_id = query.session.using_conversation.uuid or ''
|
||||
query.variables['conversation_id'] = cov_id
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
}
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
ignored_events = []
|
||||
|
||||
inputs = {}
|
||||
|
||||
inputs.update(query.variables)
|
||||
|
||||
pending_agent_message = ''
|
||||
|
||||
chunk = None # 初始化chunk变量,防止在没有响应时引用错误
|
||||
message_idx = 0
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
async for chunk in self.dify_client.chat_messages(
|
||||
inputs=inputs,
|
||||
query=plain_text,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
response_mode='streaming',
|
||||
conversation_id=cov_id,
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-agent-chunk: ' + str(chunk))
|
||||
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
|
||||
if chunk['event'] == 'agent_message':
|
||||
message_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['answer'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['answer'] and not think_end:
|
||||
import re
|
||||
content = re.sub(r'^\n</think>', '', chunk['answer'])
|
||||
pending_agent_message += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
pending_agent_message += chunk['answer']
|
||||
if think_start:
|
||||
continue
|
||||
|
||||
else:
|
||||
pending_agent_message += chunk['answer']
|
||||
elif chunk['event'] == 'message_end':
|
||||
is_final = True
|
||||
else:
|
||||
|
||||
if chunk['event'] == 'agent_thought':
|
||||
if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过
|
||||
continue
|
||||
message_idx += 1
|
||||
if chunk['tool']:
|
||||
msg = llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
tool_calls=[
|
||||
llm_entities.ToolCall(
|
||||
id=chunk['id'],
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=chunk['tool'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
yield msg
|
||||
if chunk['event'] == 'message_file':
|
||||
message_idx += 1
|
||||
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
|
||||
base_url = self.dify_client.base_url
|
||||
|
||||
if base_url.endswith('/v1'):
|
||||
base_url = base_url[:-3]
|
||||
|
||||
image_url = base_url + chunk['url']
|
||||
|
||||
yield llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=[llm_entities.ContentElement.from_image_url(image_url)],
|
||||
is_final=is_final,
|
||||
|
||||
)
|
||||
|
||||
if chunk['event'] == 'error':
|
||||
raise errors.DifyAPIError('dify 服务错误: ' + chunk['message'])
|
||||
if message_idx % 8 == 0 or is_final:
|
||||
yield llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=pending_agent_message,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
if chunk is None:
|
||||
raise errors.DifyAPIError('Dify API 没有返回任何响应,请检查网络连接和API配置')
|
||||
|
||||
query.session.using_conversation.uuid = chunk['conversation_id']
|
||||
|
||||
async def _workflow_messages_chunk(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""调用工作流"""
|
||||
|
||||
if not query.session.using_conversation.uuid:
|
||||
query.session.using_conversation.uuid = str(uuid.uuid4())
|
||||
|
||||
query.variables['conversation_id'] = query.session.using_conversation.uuid
|
||||
|
||||
plain_text, image_ids = await self._preprocess_user_message(query)
|
||||
|
||||
files = [
|
||||
{
|
||||
'type': 'image',
|
||||
'transfer_method': 'local_file',
|
||||
'upload_file_id': image_id,
|
||||
}
|
||||
for image_id in image_ids
|
||||
]
|
||||
|
||||
ignored_events = ['workflow_started']
|
||||
|
||||
inputs = { # these variables are legacy variables, we need to keep them for compatibility
|
||||
'langbot_user_message_text': plain_text,
|
||||
'langbot_session_id': query.variables['session_id'],
|
||||
'langbot_conversation_id': query.variables['conversation_id'],
|
||||
'langbot_msg_create_time': query.variables['msg_create_time'],
|
||||
}
|
||||
|
||||
inputs.update(query.variables)
|
||||
messsage_idx = 0
|
||||
is_final = False
|
||||
think_start = False
|
||||
think_end = False
|
||||
workflow_contents = ''
|
||||
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
async for chunk in self.dify_client.workflow_run(
|
||||
inputs=inputs,
|
||||
user=f'{query.session.launcher_type.value}_{query.session.launcher_id}',
|
||||
files=files,
|
||||
timeout=120,
|
||||
):
|
||||
self.ap.logger.debug('dify-workflow-chunk: ' + str(chunk))
|
||||
if chunk['event'] in ignored_events:
|
||||
continue
|
||||
if chunk['event'] == 'workflow_finished':
|
||||
is_final = True
|
||||
if chunk['data']['error']:
|
||||
raise errors.DifyAPIError(chunk['data']['error'])
|
||||
|
||||
if chunk['event'] == 'text_chunk':
|
||||
messsage_idx += 1
|
||||
if remove_think:
|
||||
if '<think>' in chunk['data']['text'] and not think_start:
|
||||
think_start = True
|
||||
continue
|
||||
if '</think>' in chunk['data']['text'] and not think_end:
|
||||
import re
|
||||
content = re.sub(r'^\n</think>', '', chunk['data']['text'])
|
||||
workflow_contents += content
|
||||
think_end = True
|
||||
elif think_end:
|
||||
workflow_contents += chunk['data']['text']
|
||||
if think_start:
|
||||
continue
|
||||
|
||||
else:
|
||||
workflow_contents += chunk['data']['text']
|
||||
|
||||
if chunk['event'] == 'node_started':
|
||||
if chunk['data']['node_type'] == 'start' or chunk['data']['node_type'] == 'end':
|
||||
continue
|
||||
messsage_idx += 1
|
||||
msg = llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=None,
|
||||
tool_calls=[
|
||||
llm_entities.ToolCall(
|
||||
id=chunk['data']['node_id'],
|
||||
type='function',
|
||||
function=llm_entities.FunctionCall(
|
||||
name=chunk['data']['title'],
|
||||
arguments=json.dumps({}),
|
||||
),
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
|
||||
if messsage_idx % 8 == 0 or is_final:
|
||||
yield llm_entities.MessageChunk(
|
||||
role='assistant',
|
||||
content=workflow_contents,
|
||||
is_final=is_final,
|
||||
)
|
||||
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
"""运行请求"""
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
if await query.adapter.is_stream_output_supported():
|
||||
msg_idx = 0
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
|
||||
async for msg in self._agent_chat_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
|
||||
async for msg in self._workflow_messages_chunk(query):
|
||||
msg_idx += 1
|
||||
msg.msg_sequence = msg_idx
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
if self.pipeline_config['ai']['dify-service-api']['app-type'] == 'chat':
|
||||
async for msg in self._chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'agent':
|
||||
async for msg in self._agent_chat_messages(query):
|
||||
yield msg
|
||||
elif self.pipeline_config['ai']['dify-service-api']['app-type'] == 'workflow':
|
||||
async for msg in self._workflow_messages(query):
|
||||
yield msg
|
||||
else:
|
||||
raise errors.DifyAPIError(
|
||||
f'不支持的 Dify 应用类型: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
|
||||
)
|
||||
@@ -27,7 +27,16 @@ Respond in the same language as the user's input.
|
||||
class LocalAgentRunner(runner.RequestRunner):
|
||||
"""本地Agent请求运行器"""
|
||||
|
||||
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
|
||||
class ToolCallTracker:
|
||||
"""工具调用追踪器"""
|
||||
|
||||
def __init__(self):
|
||||
self.active_calls: dict[str, dict] = {}
|
||||
self.completed_calls: list[llm_entities.ToolCall] = []
|
||||
|
||||
async def run(
|
||||
self, query: core_entities.Query
|
||||
) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]:
|
||||
"""运行请求"""
|
||||
pending_tool_calls = []
|
||||
|
||||
@@ -80,20 +89,93 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message]
|
||||
|
||||
# 首次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
query,
|
||||
query.use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
extra_args=query.use_llm_model.model_entity.extra_args,
|
||||
)
|
||||
try:
|
||||
is_stream = await query.adapter.is_stream_output_supported()
|
||||
except AttributeError:
|
||||
is_stream = False
|
||||
|
||||
yield msg
|
||||
remove_think = self.pipeline_config['output'].get('misc', '').get('remove-think')
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
if not is_stream:
|
||||
# 非流式输出,直接请求
|
||||
|
||||
req_messages.append(msg)
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
query,
|
||||
query.use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
extra_args=query.use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
yield msg
|
||||
final_msg = msg
|
||||
else:
|
||||
# 流式输出,需要处理工具调用
|
||||
tool_calls_map: dict[str, llm_entities.ToolCall] = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = '' # 从开始累积的所有内容
|
||||
last_role = 'assistant'
|
||||
msg_sequence = 1
|
||||
async for msg in query.use_llm_model.requester.invoke_llm_stream(
|
||||
query,
|
||||
query.use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
extra_args=query.use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
):
|
||||
msg_idx = msg_idx + 1
|
||||
|
||||
# 记录角色
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
# 累积内容
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
# 处理工具调用
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
tool_calls_map[tool_call.id] = llm_entities.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name if tool_call.function else '', arguments=''
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
# print(list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None)
|
||||
# continue
|
||||
# 每8个chunk或最后一个chunk时,输出所有累积的内容
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield llm_entities.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content, # 输出所有累积内容
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
# 创建最终消息用于后续处理
|
||||
final_msg = llm_entities.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if tool_calls_map else None,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
pending_tool_calls = final_msg.tool_calls
|
||||
first_content = final_msg.content
|
||||
if isinstance(final_msg, llm_entities.MessageChunk):
|
||||
|
||||
first_end_sequence = final_msg.msg_sequence
|
||||
|
||||
req_messages.append(final_msg)
|
||||
|
||||
# 持续请求,只要还有待处理的工具调用就继续处理调用
|
||||
while pending_tool_calls:
|
||||
@@ -104,12 +186,18 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
parameters = json.loads(func.arguments)
|
||||
|
||||
func_ret = await self.ap.tool_mgr.execute_func_call(query, func.name, parameters)
|
||||
|
||||
msg = llm_entities.Message(
|
||||
role='tool',
|
||||
content=json.dumps(func_ret, ensure_ascii=False),
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
if is_stream:
|
||||
msg = llm_entities.MessageChunk(
|
||||
role='tool',
|
||||
content=json.dumps(func_ret, ensure_ascii=False),
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
else:
|
||||
msg = llm_entities.Message(
|
||||
role='tool',
|
||||
content=json.dumps(func_ret, ensure_ascii=False),
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
|
||||
yield msg
|
||||
|
||||
@@ -122,17 +210,82 @@ class LocalAgentRunner(runner.RequestRunner):
|
||||
|
||||
req_messages.append(err_msg)
|
||||
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
query,
|
||||
query.use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
extra_args=query.use_llm_model.model_entity.extra_args,
|
||||
)
|
||||
if is_stream:
|
||||
tool_calls_map = {}
|
||||
msg_idx = 0
|
||||
accumulated_content = '' # 从开始累积的所有内容
|
||||
last_role = 'assistant'
|
||||
msg_sequence = first_end_sequence
|
||||
|
||||
yield msg
|
||||
async for msg in query.use_llm_model.requester.invoke_llm_stream(
|
||||
query,
|
||||
query.use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
extra_args=query.use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
):
|
||||
msg_idx += 1
|
||||
|
||||
pending_tool_calls = msg.tool_calls
|
||||
# 记录角色
|
||||
if msg.role:
|
||||
last_role = msg.role
|
||||
|
||||
req_messages.append(msg)
|
||||
# 第一次请求工具调用时的内容
|
||||
if msg_idx == 1:
|
||||
accumulated_content = first_content if first_content is not None else accumulated_content
|
||||
|
||||
# 累积内容
|
||||
if msg.content:
|
||||
accumulated_content += msg.content
|
||||
|
||||
# 处理工具调用
|
||||
if msg.tool_calls:
|
||||
for tool_call in msg.tool_calls:
|
||||
if tool_call.id not in tool_calls_map:
|
||||
tool_calls_map[tool_call.id] = llm_entities.ToolCall(
|
||||
id=tool_call.id,
|
||||
type=tool_call.type,
|
||||
function=llm_entities.FunctionCall(
|
||||
name=tool_call.function.name if tool_call.function else '', arguments=''
|
||||
),
|
||||
)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
# 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖
|
||||
tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
|
||||
|
||||
# 每8个chunk或最后一个chunk时,输出所有累积的内容
|
||||
if msg_idx % 8 == 0 or msg.is_final:
|
||||
msg_sequence += 1
|
||||
yield llm_entities.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content, # 输出所有累积内容
|
||||
tool_calls=list(tool_calls_map.values()) if (tool_calls_map and msg.is_final) else None,
|
||||
is_final=msg.is_final,
|
||||
msg_sequence=msg_sequence,
|
||||
)
|
||||
|
||||
final_msg = llm_entities.MessageChunk(
|
||||
role=last_role,
|
||||
content=accumulated_content,
|
||||
tool_calls=list(tool_calls_map.values()) if tool_calls_map else None,
|
||||
msg_sequence=msg_sequence,
|
||||
|
||||
)
|
||||
else:
|
||||
# 处理完所有调用,再次请求
|
||||
msg = await query.use_llm_model.requester.invoke_llm(
|
||||
query,
|
||||
query.use_llm_model,
|
||||
req_messages,
|
||||
query.use_funcs,
|
||||
extra_args=query.use_llm_model.model_entity.extra_args,
|
||||
remove_think=remove_think,
|
||||
)
|
||||
|
||||
yield msg
|
||||
final_msg = msg
|
||||
|
||||
pending_tool_calls = final_msg.tool_calls
|
||||
|
||||
req_messages.append(final_msg)
|
||||
|
||||
@@ -204,9 +204,9 @@ async def get_slack_image_to_base64(pic_url: str, bot_token: str):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(pic_url, headers=headers) as resp:
|
||||
mime_type = resp.headers.get("Content-Type", "application/octet-stream")
|
||||
mime_type = resp.headers.get('Content-Type', 'application/octet-stream')
|
||||
file_bytes = await resp.read()
|
||||
base64_str = base64.b64encode(file_bytes).decode("utf-8")
|
||||
return f"data:{mime_type};base64,{base64_str}"
|
||||
base64_str = base64.b64encode(file_bytes).decode('utf-8')
|
||||
return f'data:{mime_type};base64,{base64_str}'
|
||||
except Exception as e:
|
||||
raise (e)
|
||||
raise (e)
|
||||
|
||||
@@ -32,7 +32,7 @@ def import_dir(path: str):
|
||||
rel_path = full_path.replace(os.path.dirname(os.path.dirname(os.path.dirname(__file__))), '')
|
||||
rel_path = rel_path[1:]
|
||||
rel_path = rel_path.replace('/', '.')[:-3]
|
||||
rel_path = rel_path.replace("\\",".")
|
||||
rel_path = rel_path.replace('\\', '.')
|
||||
importlib.import_module(rel_path)
|
||||
|
||||
|
||||
|
||||
@@ -87,7 +87,8 @@
|
||||
"hide-exception": true,
|
||||
"at-sender": true,
|
||||
"quote-origin": true,
|
||||
"track-function-calls": false
|
||||
"track-function-calls": false,
|
||||
"remove-think": true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -138,20 +138,7 @@ stages:
|
||||
label:
|
||||
en_US: Remove
|
||||
zh_Hans: 移除
|
||||
- name: enable-streaming
|
||||
label:
|
||||
en_US: enable streaming mode
|
||||
zh_Hans: 开启流式输出
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: output-batch-size
|
||||
label:
|
||||
en_US: output batch size
|
||||
zh_Hans: 输出批次大小(积累多少条消息后一起输出)
|
||||
type: integer
|
||||
required: true
|
||||
default: 10
|
||||
|
||||
|
||||
- name: dashscope-app-api
|
||||
label:
|
||||
|
||||
@@ -105,3 +105,13 @@ stages:
|
||||
type: boolean
|
||||
required: true
|
||||
default: false
|
||||
- name: remove-think
|
||||
label:
|
||||
en_US: Remove CoT
|
||||
zh_Hans: 删除思维链
|
||||
description:
|
||||
en_US: If enabled, LangBot will remove the LLM thought content in response
|
||||
zh_Hans: 如果启用,将自动删除大模型回复中的模型思考内容
|
||||
type: boolean
|
||||
required: true
|
||||
default: true
|
||||
|
||||
@@ -21,6 +21,7 @@
|
||||
"dependencies": {
|
||||
"@dnd-kit/core": "^6.3.1",
|
||||
"@dnd-kit/sortable": "^10.0.0",
|
||||
"@dnd-kit/utilities": "^3.2.2",
|
||||
"@hookform/resolvers": "^5.0.1",
|
||||
"@radix-ui/react-checkbox": "^1.3.1",
|
||||
"@radix-ui/react-dialog": "^1.1.14",
|
||||
|
||||
@@ -298,6 +298,18 @@ export default function EmbeddingForm({
|
||||
|
||||
function testEmbeddingModelInForm() {
|
||||
setModelTesting(true);
|
||||
const extraArgsObj: Record<string, string | number | boolean> = {};
|
||||
form
|
||||
.getValues('extra_args')
|
||||
?.forEach((arg: { key: string; type: string; value: string }) => {
|
||||
if (arg.type === 'number') {
|
||||
extraArgsObj[arg.key] = Number(arg.value);
|
||||
} else if (arg.type === 'boolean') {
|
||||
extraArgsObj[arg.key] = arg.value === 'true';
|
||||
} else {
|
||||
extraArgsObj[arg.key] = arg.value;
|
||||
}
|
||||
});
|
||||
httpClient
|
||||
.testEmbeddingModel('_', {
|
||||
uuid: '',
|
||||
@@ -309,6 +321,7 @@ export default function EmbeddingForm({
|
||||
timeout: 120,
|
||||
},
|
||||
api_keys: [form.getValues('api_key')],
|
||||
extra_args: extraArgsObj,
|
||||
})
|
||||
.then((res) => {
|
||||
console.log(res);
|
||||
|
||||
@@ -312,6 +312,18 @@ export default function LLMForm({
|
||||
|
||||
function testLLMModelInForm() {
|
||||
setModelTesting(true);
|
||||
const extraArgsObj: Record<string, string | number | boolean> = {};
|
||||
form
|
||||
.getValues('extra_args')
|
||||
?.forEach((arg: { key: string; type: string; value: string }) => {
|
||||
if (arg.type === 'number') {
|
||||
extraArgsObj[arg.key] = Number(arg.value);
|
||||
} else if (arg.type === 'boolean') {
|
||||
extraArgsObj[arg.key] = arg.value === 'true';
|
||||
} else {
|
||||
extraArgsObj[arg.key] = arg.value;
|
||||
}
|
||||
});
|
||||
httpClient
|
||||
.testLLMModel('_', {
|
||||
uuid: '',
|
||||
@@ -324,7 +336,7 @@ export default function LLMForm({
|
||||
},
|
||||
api_keys: [form.getValues('api_key')],
|
||||
abilities: form.getValues('abilities'),
|
||||
extra_args: form.getValues('extra_args'),
|
||||
extra_args: extraArgsObj,
|
||||
})
|
||||
.then((res) => {
|
||||
console.log(res);
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState, useEffect, useRef } from 'react';
|
||||
import React, { useState, useEffect, useRef, useCallback } from 'react';
|
||||
import { useTranslation } from 'react-i18next';
|
||||
import { httpClient } from '@/app/infra/http/HttpClient';
|
||||
import { DialogContent } from '@/components/ui/dialog';
|
||||
@@ -10,6 +10,7 @@ import { cn } from '@/lib/utils';
|
||||
import { Message } from '@/app/infra/entities/message';
|
||||
import { toast } from 'sonner';
|
||||
import AtBadge from './AtBadge';
|
||||
import { Switch } from '@/components/ui/switch';
|
||||
|
||||
interface MessageComponent {
|
||||
type: 'At' | 'Plain';
|
||||
@@ -36,17 +37,44 @@ export default function DebugDialog({
|
||||
const [showAtPopover, setShowAtPopover] = useState(false);
|
||||
const [hasAt, setHasAt] = useState(false);
|
||||
const [isHovering, setIsHovering] = useState(false);
|
||||
const [isStreaming, setIsStreaming] = useState(true);
|
||||
const messagesEndRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
const popoverRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const scrollToBottom = () => {
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
||||
};
|
||||
const scrollToBottom = useCallback(() => {
|
||||
// 使用setTimeout确保在DOM更新后执行滚动
|
||||
setTimeout(() => {
|
||||
const scrollArea = document.querySelector('.scroll-area') as HTMLElement;
|
||||
if (scrollArea) {
|
||||
scrollArea.scrollTo({
|
||||
top: scrollArea.scrollHeight,
|
||||
behavior: 'smooth',
|
||||
});
|
||||
}
|
||||
// 同时确保messagesEndRef也滚动到视图
|
||||
messagesEndRef.current?.scrollIntoView({ behavior: 'smooth' });
|
||||
}, 0);
|
||||
}, []);
|
||||
|
||||
const loadMessages = useCallback(
|
||||
async (pipelineId: string) => {
|
||||
try {
|
||||
const response = await httpClient.getWebChatHistoryMessages(
|
||||
pipelineId,
|
||||
sessionType,
|
||||
);
|
||||
setMessages(response.messages);
|
||||
} catch (error) {
|
||||
console.error('Failed to load messages:', error);
|
||||
}
|
||||
},
|
||||
[sessionType],
|
||||
);
|
||||
// 在useEffect中监听messages变化时滚动
|
||||
useEffect(() => {
|
||||
scrollToBottom();
|
||||
}, [messages]);
|
||||
}, [messages, scrollToBottom]);
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
@@ -59,7 +87,7 @@ export default function DebugDialog({
|
||||
if (open) {
|
||||
loadMessages(selectedPipelineId);
|
||||
}
|
||||
}, [sessionType, selectedPipelineId]);
|
||||
}, [sessionType, selectedPipelineId, open, loadMessages]);
|
||||
|
||||
useEffect(() => {
|
||||
const handleClickOutside = (event: MouseEvent) => {
|
||||
@@ -84,18 +112,6 @@ export default function DebugDialog({
|
||||
}
|
||||
}, [showAtPopover]);
|
||||
|
||||
const loadMessages = async (pipelineId: string) => {
|
||||
try {
|
||||
const response = await httpClient.getWebChatHistoryMessages(
|
||||
pipelineId,
|
||||
sessionType,
|
||||
);
|
||||
setMessages(response.messages);
|
||||
} catch (error) {
|
||||
console.error('Failed to load messages:', error);
|
||||
}
|
||||
};
|
||||
|
||||
const handleInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const value = e.target.value;
|
||||
if (sessionType === 'group') {
|
||||
@@ -165,19 +181,87 @@ export default function DebugDialog({
|
||||
timestamp: new Date().toISOString(),
|
||||
message_chain: messageChain,
|
||||
};
|
||||
// 根据isStreaming状态决定使用哪种传输方式
|
||||
if (isStreaming) {
|
||||
// streaming
|
||||
// 创建初始bot消息
|
||||
const placeholderRandomId = Math.floor(Math.random() * 1000000);
|
||||
const botMessagePlaceholder: Message = {
|
||||
id: placeholderRandomId,
|
||||
role: 'assistant',
|
||||
content: 'Generating...',
|
||||
timestamp: new Date().toISOString(),
|
||||
message_chain: [{ type: 'Plain', text: 'Generating...' }],
|
||||
};
|
||||
|
||||
setMessages((prevMessages) => [...prevMessages, userMessage]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
// 添加用户消息和初始bot消息到状态
|
||||
|
||||
const response = await httpClient.sendWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
120000,
|
||||
);
|
||||
setMessages((prevMessages) => [
|
||||
...prevMessages,
|
||||
userMessage,
|
||||
botMessagePlaceholder,
|
||||
]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
try {
|
||||
await httpClient.sendStreamingWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
(data) => {
|
||||
// 处理流式响应数据
|
||||
console.log('data', data);
|
||||
if (data.message) {
|
||||
// 更新完整内容
|
||||
|
||||
setMessages((prevMessages) => [...prevMessages, response.message]);
|
||||
setMessages((prevMessages) => {
|
||||
const updatedMessages = [...prevMessages];
|
||||
const botMessageIndex = updatedMessages.findIndex(
|
||||
(message) => message.id === placeholderRandomId,
|
||||
);
|
||||
if (botMessageIndex !== -1) {
|
||||
updatedMessages[botMessageIndex] = {
|
||||
...updatedMessages[botMessageIndex],
|
||||
content: data.message.content,
|
||||
message_chain: [
|
||||
{ type: 'Plain', text: data.message.content },
|
||||
],
|
||||
};
|
||||
}
|
||||
return updatedMessages;
|
||||
});
|
||||
}
|
||||
},
|
||||
() => {},
|
||||
(error) => {
|
||||
// 处理错误
|
||||
console.error('Streaming error:', error);
|
||||
if (sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
},
|
||||
);
|
||||
} catch (error) {
|
||||
console.error('Failed to send streaming message:', error);
|
||||
if (sessionType === 'person') {
|
||||
toast.error(t('pipelines.debugDialog.sendFailed'));
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// non-streaming
|
||||
setMessages((prevMessages) => [...prevMessages, userMessage]);
|
||||
setInputValue('');
|
||||
setHasAt(false);
|
||||
|
||||
const response = await httpClient.sendWebChatMessage(
|
||||
sessionType,
|
||||
messageChain,
|
||||
selectedPipelineId,
|
||||
180000,
|
||||
);
|
||||
|
||||
setMessages((prevMessages) => [...prevMessages, response.message]);
|
||||
}
|
||||
} catch (
|
||||
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
||||
error: any
|
||||
@@ -306,6 +390,12 @@ export default function DebugDialog({
|
||||
</ScrollArea>
|
||||
|
||||
<div className="p-4 pb-0 bg-white flex gap-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<span className="text-sm text-gray-600">
|
||||
{t('pipelines.debugDialog.streaming')}
|
||||
</span>
|
||||
<Switch checked={isStreaming} onCheckedChange={setIsStreaming} />
|
||||
</div>
|
||||
<div className="flex-1 flex items-center gap-2">
|
||||
{hasAt && (
|
||||
<AtBadge targetName="webchatbot" onRemove={handleAtRemove} />
|
||||
|
||||
@@ -372,6 +372,99 @@ class HttpClient {
|
||||
);
|
||||
}
|
||||
|
||||
public async sendStreamingWebChatMessage(
|
||||
sessionType: string,
|
||||
messageChain: object[],
|
||||
pipelineId: string,
|
||||
onMessage: (data: ApiRespWebChatMessage) => void,
|
||||
onComplete: () => void,
|
||||
onError: (error: Error) => void,
|
||||
): Promise<void> {
|
||||
try {
|
||||
// 构造完整的URL,处理相对路径的情况
|
||||
let url = `${this.baseURL}/api/v1/pipelines/${pipelineId}/chat/send`;
|
||||
if (this.baseURL === '/') {
|
||||
// 获取用户访问的完整URL
|
||||
const baseURL = window.location.origin;
|
||||
url = `${baseURL}/api/v1/pipelines/${pipelineId}/chat/send`;
|
||||
}
|
||||
|
||||
// 使用fetch发送流式请求,因为axios在浏览器环境中不直接支持流式响应
|
||||
const response = await fetch(url, {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
Authorization: `Bearer ${this.getSessionSync()}`,
|
||||
},
|
||||
body: JSON.stringify({
|
||||
session_type: sessionType,
|
||||
message: messageChain,
|
||||
is_stream: true,
|
||||
}),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
if (!response.body) {
|
||||
throw new Error('ReadableStream not supported');
|
||||
}
|
||||
|
||||
const reader = response.body.getReader();
|
||||
const decoder = new TextDecoder();
|
||||
let buffer = '';
|
||||
|
||||
// 读取流式响应
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read();
|
||||
|
||||
if (done) {
|
||||
onComplete();
|
||||
break;
|
||||
}
|
||||
|
||||
// 解码数据
|
||||
buffer += decoder.decode(value, { stream: true });
|
||||
|
||||
// 处理完整的JSON对象
|
||||
const lines = buffer.split('\n\n');
|
||||
buffer = lines.pop() || '';
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data:')) {
|
||||
try {
|
||||
const data = JSON.parse(line.slice(5));
|
||||
|
||||
if (data.type === 'end') {
|
||||
// 流传输结束
|
||||
reader.cancel();
|
||||
onComplete();
|
||||
return;
|
||||
}
|
||||
if (data.type === 'start') {
|
||||
console.log(data.type);
|
||||
}
|
||||
|
||||
if (data.message) {
|
||||
// 处理消息数据
|
||||
onMessage(data);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error('Error parsing streaming data:', error);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock();
|
||||
}
|
||||
} catch (error) {
|
||||
onError(error as Error);
|
||||
}
|
||||
}
|
||||
|
||||
public getWebChatHistoryMessages(
|
||||
pipelineId: string,
|
||||
sessionType: string,
|
||||
|
||||
@@ -233,6 +233,7 @@ const enUS = {
|
||||
loadMessagesFailed: 'Failed to load messages',
|
||||
loadPipelinesFailed: 'Failed to load pipelines',
|
||||
atTips: 'Mention the bot',
|
||||
streaming: 'Streaming',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
|
||||
@@ -235,6 +235,7 @@ const jaJP = {
|
||||
loadMessagesFailed: 'メッセージの読み込みに失敗しました',
|
||||
loadPipelinesFailed: 'パイプラインの読み込みに失敗しました',
|
||||
atTips: 'ボットをメンション',
|
||||
streaming: 'ストリーミング',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
|
||||
@@ -228,6 +228,7 @@ const zhHans = {
|
||||
loadMessagesFailed: '加载消息失败',
|
||||
loadPipelinesFailed: '加载流水线失败',
|
||||
atTips: '提及机器人',
|
||||
streaming: '流式传输',
|
||||
},
|
||||
},
|
||||
knowledge: {
|
||||
|
||||
Reference in New Issue
Block a user