fix: remove leftover debug prints, base the response-back streaming check on is_stream_output_supported(), and drop the redundant is_stream_output_supported() call in the DingTalk adapter

Dong_master
2025-07-29 23:09:02 +08:00
committed by Junyan Qin
parent 074d359c8e
commit a9776b7b53
10 changed files with 127 additions and 186 deletions

View File

@@ -39,11 +39,9 @@ class SendResponseBackStage(stage.PipelineStage):
        quote_origin = query.pipeline_config['output']['misc']['quote-origin']
-        has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
-        print(has_chunks)
-        if has_chunks and hasattr(query.adapter,'reply_message_chunk'):
+        # has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
+        if await query.adapter.is_stream_output_supported():
            is_final = [msg.is_final for msg in query.resp_messages][0]
-            print(is_final)
            await query.adapter.reply_message_chunk(
                message_source=query.message_event,
                message_id=query.resp_messages[-1].resp_message_id,
@@ -58,10 +56,6 @@ class SendResponseBackStage(stage.PipelineStage):
                quote_origin=quote_origin,
            )
-        # await query.adapter.reply_message(
-        #     message_source=query.message_event,
-        #     message=query.resp_message_chain[-1],
-        #     quote_origin=quote_origin,
-        # )
        return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
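
Note: the hunk above replaces the per-message MessageChunk check with a single capability query on the adapter. A minimal, runnable sketch of that dispatch follows; FakeAdapter and send_response_back are illustrative stand-ins, not the project's actual classes.

    import asyncio


    class FakeAdapter:
        """Illustrative stand-in for a platform adapter (not the project's class)."""

        def __init__(self, streaming: bool):
            self._streaming = streaming

        async def is_stream_output_supported(self) -> bool:
            # real adapters typically read this from their own config, e.g. 'enable-stream-reply'
            return self._streaming

        async def reply_message_chunk(self, message_id: str, text: str, is_final: bool) -> None:
            print(f'[chunk] {message_id} final={is_final}: {text!r}')

        async def reply_message(self, text: str) -> None:
            print(f'[full ] {text!r}')


    async def send_response_back(adapter: FakeAdapter, message_id: str, text: str, is_final: bool) -> None:
        # ask the adapter whether it can stream, instead of inspecting the response messages
        if await adapter.is_stream_output_supported():
            await adapter.reply_message_chunk(message_id, text, is_final=is_final)
        else:
            await adapter.reply_message(text)


    asyncio.run(send_response_back(FakeAdapter(streaming=True), 'msg-1', 'partial answer...', False))
    asyncio.run(send_response_back(FakeAdapter(streaming=False), 'msg-2', 'final answer', True))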

View File

@@ -25,7 +25,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
    logger: EventLogger
-    is_stream: bool
    def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
        """Initialize the adapter
@@ -70,18 +69,23 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
        message: platform_message.MessageChain,
        quote_origin: bool = False,
        is_final: bool = False,
    ):
        """Reply to a message (streaming output)
        Args:
            message_source (platform.types.MessageEvent): source message event
            message_id (int): message ID
            message (platform.types.MessageChain): message chain
            quote_origin (bool, optional): whether to quote the original message. Defaults to False.
+            is_final (bool, optional): whether the stream has ended. Defaults to False.
        """
        raise NotImplementedError
-    async def create_message_card(self,message_id,event):
-        '''Create a card message'''
+    async def create_message_card(self, message_id:typing.Type[str,int], event:platform_events.MessageEvent) -> bool:
+        """Create a card message
+        Args:
+            message_id (str): message ID
+            event (platform_events.MessageEvent): source message event
+        """
        return False
    async def is_muted(self, group_id: int) -> bool:
@@ -118,10 +122,8 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
"""异步运行""" """异步运行"""
raise NotImplementedError raise NotImplementedError
async def is_stream_output_supported(self) -> bool: async def is_stream_output_supported(self) -> bool:
"""是否支持流式输出""" """是否支持流式输出"""
self.is_stream = False
return False return False
async def kill(self) -> bool: async def kill(self) -> bool:
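
Note: with the is_stream attribute removed, streaming support is expressed only through the overridable coroutine with a conservative default. A rough sketch of that base-class contract, using simplified illustrative names rather than the project's types:

    import abc
    import asyncio


    class BaseAdapter(metaclass=abc.ABCMeta):
        """Simplified stand-in for MessagePlatformAdapter."""

        def __init__(self, config: dict):
            self.config = config

        async def is_stream_output_supported(self) -> bool:
            # conservative default: adapters that cannot stream simply inherit this
            return False

        async def create_message_card(self, message_id, event) -> bool:
            # default no-op; card-capable adapters override and return True
            return False


    class StreamingAdapter(BaseAdapter):
        async def is_stream_output_supported(self) -> bool:
            # concrete adapters typically key this off their own config entry
            return bool(self.config.get('enable-stream-reply'))


    print(asyncio.run(StreamingAdapter({'enable-stream-reply': True}).is_stream_output_supported()))  # True
    print(asyncio.run(BaseAdapter({}).is_stream_output_supported()))  # False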

View File

@@ -148,7 +148,7 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
        message: platform_message.MessageChain,
        quote_origin: bool = False,
        is_final: bool = False,
    ):
        event = await DingTalkEventConverter.yiri2target(
            message_source,
        )
@@ -158,13 +158,12 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
        content, at = await DingTalkMessageConverter.yiri2target(message)
-        card_instance,card_instance_id = self.card_instance_id_dict[message_id]
-        # print(card_instance_id)
-        await self.bot.send_card_message(card_instance,card_instance_id,content,is_final)
+        card_instance, card_instance_id = self.card_instance_id_dict[message_id]
+        await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
        if is_final:
            self.card_instance_id_dict.pop(message_id)
    async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain):
        content = await DingTalkMessageConverter.yiri2target(message)
        if target_type == 'person':
@@ -174,11 +173,11 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
    async def is_stream_output_supported(self) -> bool:
        is_stream = False
-        if self.config.get("enable-stream-reply", None):
+        if self.config.get('enable-stream-reply', None):
            is_stream = True
        return is_stream
-    async def create_message_card(self,message_id,event):
+    async def create_message_card(self, message_id, event):
        card_template_id = self.config['card_template_id']
        incoming_message = event.source_platform_object.incoming_message
        # message_id = incoming_message.message_id
@@ -186,7 +185,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
        self.card_instance_id_dict[message_id] = (card_instance, card_instance_id)
        return True
    def register_listener(
        self,
        event_type: typing.Type[platform_events.Event],
@@ -194,7 +192,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter):
    ):
        async def on_message(event: DingTalkEvent):
            try:
-                await self.is_stream_output_supported()
                return await callback(
                    await self.event_converter.target2yiri(event, self.config['robot_name']),
                    self,
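
Note: the DingTalk hunks keep a per-message card instance in card_instance_id_dict and drop it once the final chunk has been sent. A rough sketch of that create/update/pop lifecycle with a stubbed client; StubBot and CardLifecycle are illustrative, not the DingTalk SDK:

    import asyncio


    class StubBot:
        """Stand-in for the DingTalk client used by the adapter (illustrative)."""

        async def create_card(self, template_id: str) -> tuple[str, str]:
            return ('card-instance', 'card-instance-id')

        async def send_card_message(self, instance, instance_id, content, is_final):
            print(f'update {instance_id}: {content!r} (final={is_final})')


    class CardLifecycle:
        def __init__(self, bot: StubBot, template_id: str):
            self.bot = bot
            self.template_id = template_id
            self.card_instance_id_dict: dict[str, tuple[str, str]] = {}

        async def create_message_card(self, message_id: str) -> bool:
            # one card per incoming message id
            self.card_instance_id_dict[message_id] = await self.bot.create_card(self.template_id)
            return True

        async def reply_message_chunk(self, message_id: str, content: str, is_final: bool) -> None:
            card_instance, card_instance_id = self.card_instance_id_dict[message_id]
            await self.bot.send_card_message(card_instance, card_instance_id, content, is_final)
            if is_final:
                # the card is no longer needed once the stream has finished
                self.card_instance_id_dict.pop(message_id)


    async def demo():
        lifecycle = CardLifecycle(StubBot(), 'template-123')
        await lifecycle.create_message_card('m1')
        await lifecycle.reply_message_chunk('m1', 'thinking...', is_final=False)
        await lifecycle.reply_message_chunk('m1', 'done.', is_final=True)


    asyncio.run(demo())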

View File

@@ -9,7 +9,6 @@ import re
import base64
import uuid
import json
-import time
import datetime
import hashlib
from Crypto.Cipher import AES
@@ -345,11 +344,10 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
    quart_app: quart.Quart
    ap: app.Application
-    message_id_to_card_id: typing.Dict[str, typing.Tuple[str, int]]
-    card_id_dict: dict[str, str]
-    seq: int
+    card_id_dict: dict[str, str]  # maps message id to card id, so later messages can be sent to that card
+    seq: int  # identifies message order when sending card messages; seq is used directly as the marker
    def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
        self.config = config
@@ -357,10 +355,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
        self.logger = logger
        self.quart_app = quart.Quart(__name__)
        self.listeners = {}
-        self.message_id_to_card_id = {}
        self.card_id_dict = {}
        self.seq = 1
-        self.card_id_time = {}
        @self.quart_app.route('/lark/callback', methods=['POST'])
        async def lark_callback():
@@ -405,15 +402,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
                await self.logger.error(f'Error in lark callback: {traceback.format_exc()}')
                return {'code': 500, 'message': 'error'}
        async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1):
            lb_event = await self.event_converter.target2yiri(event, self.api_client)
            await self.listeners[type(lb_event)](lb_event, self)
@@ -435,51 +424,49 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
    async def is_stream_output_supported(self) -> bool:
        is_stream = False
-        if self.config.get("enable-stream-reply", None):
+        if self.config.get('enable-stream-reply', None):
            is_stream = True
        return is_stream
-    async def create_card_id(self,message_id):
+    async def create_card_id(self, message_id):
        try:
-            is_stream = await self.is_stream_output_supported()
-            if is_stream:
-                self.ap.logger.debug('飞书支持stream输出,创建卡片......')
-                card_data = {"schema": "2.0", "header": {"title": {"content": "bot", "tag": "plain_text"}},
-                             "body": {"elements": [
-                                 {"tag": "markdown", "content": "[思考中.....]", "element_id": "markdown_1"}]},
-                             "config": {"streaming_mode": True,
-                                        "streaming_config": {"print_strategy": "delay"}}}  # delay / fast
-                request: CreateCardRequest = CreateCardRequest.builder() \
-                    .request_body(
-                        CreateCardRequestBody.builder()
-                        .type("card_json")
-                        .data(json.dumps(card_data)) \
-                        .build()
-                    ).build()
+            self.ap.logger.debug('飞书支持stream输出,创建卡片......')
+            card_data = {
+                'schema': '2.0',
+                'header': {'title': {'content': 'bot', 'tag': 'plain_text'}},
+                'body': {'elements': [{'tag': 'markdown', 'content': '[思考中.....]', 'element_id': 'markdown_1'}]},
+                'config': {'streaming_mode': True, 'streaming_config': {'print_strategy': 'delay'}},
+            }  # delay / fast: card template; 'delay' prints gradually, 'fast' prints in real time; a nicer message template can be customized
+            request: CreateCardRequest = (
+                CreateCardRequest.builder()
+                .request_body(CreateCardRequestBody.builder().type('card_json').data(json.dumps(card_data)).build())
+                .build()
+            )
            # send the request
            response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request)
            # handle a failed response
            if not response.success():
                raise Exception(
-                    f"client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}")
+                    f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
+                )
            self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}')
            self.card_id_dict[message_id] = response.data.card_id
            card_id = response.data.card_id
            return card_id
        except Exception as e:
            self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}')
-    async def create_message_card(self,message_id,event) -> str:
+    async def create_message_card(self, message_id, event) -> str:
        """
        Create a card message.
-        Card messages are used because ordinary messages have a limit on how many times they can be updated, and a streaming LLM reply may exceed that limit; Feishu cards do not have this limit
+        Card messages are used because ordinary messages have a limit on how many times they can be updated, and a streaming LLM reply may exceed that limit; Feishu cards do not have this limit (the free API quota is limited)
        """
        # message_id = event.message_chain.message_id
@@ -487,7 +474,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
        content = {
            'type': 'card',
            'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}},
-        }
+        }  # message template sent when a message is received; template variables can be added, see the Feishu API docs
        request: ReplyMessageRequest = (
            ReplyMessageRequest.builder()
            .message_id(event.message_chain.message_id)
@@ -545,7 +532,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
                f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
            )
    async def reply_message_chunk(
        self,
        message_source: platform_events.MessageEvent,
@@ -557,56 +543,50 @@ class LarkAdapter(adapter.MessagePlatformAdapter):
""" """
回复消息变成更新卡片消息 回复消息变成更新卡片消息
""" """
lark_message = await self.message_converter.yiri2target(message, self.api_client)
self.seq += 1 self.seq += 1
if (self.seq - 1) % 8 == 0 or is_final:
lark_message = await self.message_converter.yiri2target(message, self.api_client)
text_message = '' text_message = ''
for ele in lark_message[0]: for ele in lark_message[0]:
if ele['tag'] == 'text': if ele['tag'] == 'text':
text_message += ele['text'] text_message += ele['text']
elif ele['tag'] == 'md': elif ele['tag'] == 'md':
text_message += ele['text'] text_message += ele['text']
print(text_message)
content = { # content = {
'type': 'card_json', # 'type': 'card_json',
'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, # 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}},
} # }
request: ContentCardElementRequest = ContentCardElementRequest.builder() \ request: ContentCardElementRequest = (
.card_id(self.card_id_dict[message_id]) \ ContentCardElementRequest.builder()
.element_id("markdown_1") \ .card_id(self.card_id_dict[message_id])
.request_body(ContentCardElementRequestBody.builder() .element_id('markdown_1')
# .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204") .request_body(
.content(text_message) ContentCardElementRequestBody.builder()
.sequence(self.seq) # .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204")
.build()) \ .content(text_message)
.build() .sequence(self.seq)
.build()
if is_final: )
self.seq = 1 .build()
self.card_id_dict.pop(message_id)
# 发起请求
response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request)
# 处理失败返回
if not response.success():
raise Exception(
f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
) )
return
if is_final:
self.seq = 1 # 消息回复结束之后重置seq
self.card_id_dict.pop(message_id) # 清理已经使用过的卡片
# 发起请求
response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request)
# 处理失败返回
if not response.success():
raise Exception(
f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}'
)
return
    async def is_muted(self, group_id: int) -> bool:
        return False
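
Note: the new reply_message_chunk only pushes a card update on every 8th chunk or on the final one, and resets seq and evicts the card when the stream ends. A minimal sketch of that throttling logic, independent of the Feishu SDK (class and attribute names are illustrative):

    class ThrottledCardUpdater:
        """Illustrative model of the seq-based throttle used for Feishu card updates."""

        def __init__(self, every: int = 8):
            self.every = every
            self.seq = 1
            self.card_id_dict: dict[str, str] = {'m1': 'card-abc'}
            self.sent: list[tuple[int, str]] = []

        def reply_message_chunk(self, message_id: str, text: str, is_final: bool) -> None:
            self.seq += 1
            # skip most chunks; only every Nth update (or the final one) hits the card API
            if (self.seq - 1) % self.every != 0 and not is_final:
                return
            self.sent.append((self.seq, text))  # stands in for the card_element.content request
            if is_final:
                self.seq = 1                        # reset ordering for the next conversation
                self.card_id_dict.pop(message_id)   # the card is finished, stop tracking it


    updater = ThrottledCardUpdater()
    for i in range(1, 20):
        updater.reply_message_chunk('m1', f'partial {i}', is_final=(i == 19))
    print(updater.sent)  # only a handful of intermediate updates plus the final one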

View File

@@ -167,7 +167,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
                await self.listeners[type(lb_event)](lb_event, self)
                await self.is_stream_output_supported()
            except Exception:
-                await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}")
+                await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}')
        self.application = ApplicationBuilder().token(self.config['token']).build()
        self.bot = self.application.bot
self.bot = self.application.bot self.bot = self.application.bot
@@ -206,7 +206,6 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
        await self.bot.send_message(**args)
    async def reply_message_chunk(
        self,
        message_source: platform_events.MessageEvent,
@@ -214,8 +213,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
        message: platform_message.MessageChain,
        quote_origin: bool = False,
        is_final: bool = False,
    ):
        assert isinstance(message_source.source_platform_object, Update)
        components = await TelegramMessageConverter.yiri2target(message, self.bot)
        args = {}
@@ -240,7 +238,6 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
            if self.config['markdown_card'] is True:
                args['parse_mode'] = 'MarkdownV2'
            send_msg = await self.bot.send_message(**args)
            send_msg_id = send_msg.message_id
            self.msg_stream_id[message_id] = send_msg_id
@@ -264,16 +261,12 @@ class TelegramAdapter(adapter.MessagePlatformAdapter):
            if is_final:
                self.msg_stream_id.pop(message_id)
    async def is_stream_output_supported(self) -> bool:
        is_stream = False
-        if self.config.get("enable-stream-reply", None):
+        if self.config.get('enable-stream-reply', None):
            is_stream = True
-        self.is_stream = is_stream
        return is_stream
    async def is_muted(self, group_id: int) -> bool:
        return False

View File

@@ -17,14 +17,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
"""OpenAI ChatCompletion API 请求器""" """OpenAI ChatCompletion API 请求器"""
client: openai.AsyncClient client: openai.AsyncClient
is_content:bool is_content: bool
default_config: dict[str, typing.Any] = { default_config: dict[str, typing.Any] = {
'base_url': 'https://api.openai.com/v1', 'base_url': 'https://api.openai.com/v1',
'timeout': 120, 'timeout': 120,
} }
async def initialize(self): async def initialize(self):
self.client = openai.AsyncClient( self.client = openai.AsyncClient(
api_key='', api_key='',
@@ -46,7 +45,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        args: dict,
        extra_body: dict = {},
    ) -> chat_completion.ChatCompletion:
        async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body):
            yield chunk
@@ -66,11 +64,12 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        # deepseek's reasoner model
        if pipeline_config['trigger'].get('misc', '').get('remove_think'):
            pass
        else:
-            if reasoning_content is not None :
-                chatcmpl_message['content'] = '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
+            if reasoning_content is not None:
+                chatcmpl_message['content'] = (
+                    '<think>\n' + reasoning_content + '\n</think>\n' + chatcmpl_message['content']
+                )
        message = llm_entities.Message(**chatcmpl_message)
@@ -82,7 +81,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        chat_completion: chat_completion.ChatCompletion,
        idx: int,
    ) -> llm_entities.MessageChunk:
        # handle the difference between a streaming chunk and a full response
        # print(chat_completion.choices[0])
        if hasattr(chat_completion, 'choices'):
@@ -98,7 +96,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        if 'role' not in delta or delta['role'] is None:
            delta['role'] = 'assistant'
        reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None
        delta['content'] = '' if delta['content'] is None else delta['content']
@@ -106,13 +103,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        # deepseek's reasoner model
        if pipeline_config['trigger'].get('misc', '').get('remove_think'):
-            if reasoning_content is not None :
+            if reasoning_content is not None:
                pass
            else:
                delta['content'] = delta['content']
        else:
            if reasoning_content is not None and idx == 0:
                delta['content'] += f'<think>\n{reasoning_content}'
            elif reasoning_content is None:
                if self.is_content:
                    delta['content'] = delta['content']
@@ -122,7 +119,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
            else:
                delta['content'] += reasoning_content
        message = llm_entities.MessageChunk(**delta)
        return message
@@ -135,9 +131,10 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        use_funcs: list[tools_entities.LLMFunction] = None,
        stream: bool = False,
        extra_args: dict[str, typing.Any] = {},
-    ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
+    ) ->llm_entities.MessageChunk:
        self.client.api_key = use_model.token_mgr.get_token()
        args = {}
        args['model'] = use_model.model_entity.name
@@ -163,14 +160,14 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        if stream:
            current_content = ''
-            args["stream"] = True
+            args['stream'] = True
            chunk_idx = 0
            self.is_content = False
            tool_calls_map: dict[str, llm_entities.ToolCall] = {}
            pipeline_config = query.pipeline_config
            async for chunk in self._req_stream(args, extra_body=extra_args):
                # handle the streaming message
-                delta_message = await self._make_msg_chunk(pipeline_config,chunk,chunk_idx)
+                delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx)
                if delta_message.content:
                    current_content += delta_message.content
                    delta_message.content = current_content
@@ -182,15 +179,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
                            id=tool_call.id,
                            type=tool_call.type,
                            function=llm_entities.FunctionCall(
-                                name=tool_call.function.name if tool_call.function else '',
-                                arguments=''
+                                name=tool_call.function.name if tool_call.function else '', arguments=''
                            ),
                        )
                    if tool_call.function and tool_call.function.arguments:
                        # during streaming, tool-call arguments may arrive across several chunks, so append instead of overwriting
                        tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
                chunk_idx += 1
                chunk_choices = getattr(chunk, 'choices', None)
                if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None):
@@ -198,11 +193,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
                    delta_message.content = current_content
                if chunk_idx % 64 == 0 or delta_message.is_final:
                    yield delta_message
            # return
    async def _closure(
        self,
        query: core_entities.Query,
@@ -211,7 +204,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        use_funcs: list[tools_entities.LLMFunction] = None,
        stream: bool = False,
        extra_args: dict[str, typing.Any] = {},
-    ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]:
+    ) -> llm_entities.Message:
        self.client.api_key = use_model.token_mgr.get_token()
        args = {}
@@ -237,22 +230,15 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        args['messages'] = messages
        # send the request
        resp = await self._req(args, extra_body=extra_args)
        # handle the result
        pipeline_config = query.pipeline_config
-        message = await self._make_msg(resp,pipeline_config)
+        message = await self._make_msg(resp, pipeline_config)
        return message
    async def invoke_llm(
        self,
        query: core_entities.Query,
@@ -273,7 +259,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
            req_messages.append(msg_dict)
        try:
            msg = await self._closure(
                query=query,
                req_messages=req_messages,
@@ -334,7 +319,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
        funcs: typing.List[tools_entities.LLMFunction] = None,
        stream: bool = False,
        extra_args: dict[str, typing.Any] = {},
    ) -> llm_entities.MessageChunk:
        req_messages = []  # req_messages is only used within this class; external state stays in sync via query.messages
        for m in messages:
            msg_dict = m.dict(exclude_none=True)
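
Note: the streaming branch of this requester keeps a running snapshot of the content, appends tool-call argument fragments per call id, and only yields on every 64th chunk or on the final one. A self-contained sketch of that accumulation; the Chunk type and consume_stream helper are illustrative, not the project's entities:

    from dataclasses import dataclass, field


    @dataclass
    class Chunk:
        """Illustrative stand-in for one streamed delta from the model API."""
        content: str = ''
        tool_args: dict[str, str] = field(default_factory=dict)  # tool_call id -> argument fragment
        is_final: bool = False


    def consume_stream(chunks):
        """Accumulate deltas: content grows into a running snapshot, tool-call arguments
        are appended per id, and snapshots are emitted every 64th chunk or on the final one."""
        current_content = ''
        tool_calls: dict[str, str] = {}
        emitted = []
        for idx, chunk in enumerate(chunks):
            current_content += chunk.content
            for call_id, fragment in chunk.tool_args.items():
                # arguments for one tool call may be split across chunks: append, never overwrite
                tool_calls[call_id] = tool_calls.get(call_id, '') + fragment
            if idx % 64 == 0 or chunk.is_final:
                emitted.append((current_content, dict(tool_calls)))
        return emitted


    stream = [Chunk(content=f'tok{i} ') for i in range(130)]
    stream[-1].is_final = True
    stream[3].tool_args = {'call_1': '{"city": "'}
    stream[4].tool_args = {'call_1': 'Paris"}'}
    print(len(consume_stream(stream)))  # a few snapshots rather than one per token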

View File

@@ -55,6 +55,6 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions):
            raise errors.RequesterError('接口返回为空,请确定模型提供商服务是否正常')
        pipeline_config = query.pipeline_config
        # handle the result
-        message = await self._make_msg(resp,pipeline_config)
+        message = await self._make_msg(resp, pipeline_config)
        return message

View File

@@ -185,8 +185,6 @@ class DashScopeAPIRunner(runner.RequestRunner):
                # substitute the reference material into the text
                pending_content = self._replace_references(pending_content, references_dict)
                yield llm_entities.Message(
                    role='assistant',
                    content=pending_content,
@@ -261,13 +259,11 @@ class DashScopeAPIRunner(runner.RequestRunner):
                    role='assistant',
                    content=pending_content,
                    is_final=is_final,
                )
            # save the session_id of the current conversation as context for the next turn
            query.session.using_conversation.uuid = stream_output.get('session_id')
        else:
            for chunk in response:
                if chunk.get('status_code') != 200:

View File

@@ -148,7 +148,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
            if mode == 'workflow':
                if chunk['event'] == 'node_finished':
                    if not is_stream:
                        if chunk['data']['node_type'] == 'answer':
                            yield llm_entities.Message(
                                role='assistant',
@@ -274,7 +273,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
                    content=self._try_convert_thinking(pending_agent_message),
                )
            if chunk['event'] == 'agent_thought':
                if chunk['tool'] != '' and chunk['observation'] != '':  # tool-call result, skip
                    continue

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import json
import copy
-from ssl import ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
import typing
from .. import runner
from ...core import entities as core_entities
@@ -30,11 +29,14 @@ class LocalAgentRunner(runner.RequestRunner):
    class ToolCallTracker:
        """Tool call tracker"""
        def __init__(self):
-            self.active_calls: dict[str,dict] = {}
+            self.active_calls: dict[str, dict] = {}
            self.completed_calls: list[llm_entities.ToolCall] = []
-    async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]:
+    async def run(
+        self, query: core_entities.Query
+    ) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]:
        """Run the request"""
        pending_tool_calls = []
@@ -89,16 +91,14 @@ class LocalAgentRunner(runner.RequestRunner):
        is_stream = query.adapter.is_stream_output_supported()
        try:
-            # print(await query.adapter.is_stream_output_supported())
            is_stream = await query.adapter.is_stream_output_supported()
        except AttributeError:
            is_stream = False
-        # while True:
-        #     pass
        if not is_stream:
            # non-streaming output, send the request directly
-            # print(123)
            msg = await query.use_llm_model.requester.invoke_llm(
                query,
                query.use_llm_model,
@@ -108,7 +108,6 @@ class LocalAgentRunner(runner.RequestRunner):
            )
            yield msg
            final_msg = msg
-            print(final_msg)
        else:
            # streaming output, tool calls need to be handled
            tool_calls_map: dict[str, llm_entities.ToolCall] = {}
@@ -122,27 +121,26 @@ class LocalAgentRunner(runner.RequestRunner):
            ):
                assert isinstance(msg, llm_entities.MessageChunk)
                yield msg
-                # if msg.tool_calls:
-                #     for tool_call in msg.tool_calls:
-                #         if tool_call.id not in tool_calls_map:
-                #             tool_calls_map[tool_call.id] = llm_entities.ToolCall(
-                #                 id=tool_call.id,
-                #                 type=tool_call.type,
-                #                 function=llm_entities.FunctionCall(
-                #                     name=tool_call.function.name if tool_call.function else '',
-                #                     arguments=''
-                #                 ),
-                #             )
-                #         if tool_call.function and tool_call.function.arguments:
-                #             # during streaming, tool-call arguments may arrive across several chunks, so append instead of overwriting
-                #             tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
+                if msg.tool_calls:
+                    for tool_call in msg.tool_calls:
+                        if tool_call.id not in tool_calls_map:
+                            tool_calls_map[tool_call.id] = llm_entities.ToolCall(
+                                id=tool_call.id,
+                                type=tool_call.type,
+                                function=llm_entities.FunctionCall(
+                                    name=tool_call.function.name if tool_call.function else '',
+                                    arguments=''
+                                ),
+                            )
+                        if tool_call.function and tool_call.function.arguments:
+                            # during streaming, tool-call arguments may arrive across several chunks, so append instead of overwriting
+                            tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments
            final_msg = llm_entities.Message(
                role=msg.role,
                content=msg.all_content,
                tool_calls=list(tool_calls_map.values()),
            )
            pending_tool_calls = final_msg.tool_calls
            req_messages.append(final_msg)
@@ -193,8 +191,7 @@ class LocalAgentRunner(runner.RequestRunner):
                            id=tool_call.id,
                            type=tool_call.type,
                            function=llm_entities.FunctionCall(
-                                name=tool_call.function.name if tool_call.function else '',
-                                arguments=''
+                                name=tool_call.function.name if tool_call.function else '', arguments=''
                            ),
                        )
                    if tool_call.function and tool_call.function.arguments:
@@ -206,7 +203,6 @@ class LocalAgentRunner(runner.RequestRunner):
                tool_calls=list(tool_calls_map.values()),
            )
        else:
-            print("非流式")
            # after all tool calls are handled, send the request again
            msg = await query.use_llm_model.requester.invoke_llm(
                query,
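
Note: the runner now probes the adapter for streaming support and quietly falls back to a non-streaming request if the adapter predates the method. A small sketch of that defensive probe, using illustrative adapter stubs rather than the project's classes:

    import asyncio


    class OldAdapter:
        """Illustrative legacy adapter without is_stream_output_supported()."""


    class NewAdapter:
        async def is_stream_output_supported(self) -> bool:
            return True


    async def detect_stream(adapter) -> bool:
        try:
            # adapters that predate the capability method raise AttributeError here
            return await adapter.is_stream_output_supported()
        except AttributeError:
            return False


    print(asyncio.run(detect_stream(NewAdapter())))  # True
    print(asyncio.run(detect_stream(OldAdapter())))  # False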