From a9776b7b53c45b3b8ebb52e0638c631e79642dda Mon Sep 17 00:00:00 2001
From: Dong_master <2213070223@qq.com>
Date: Tue, 29 Jul 2025 23:09:02 +0800
Subject: [PATCH] fix: remove leftover debug prints, gate the respback
 streaming reply on is_stream_output_supported(), and drop the redundant
 is_stream_output_supported() call in the DingTalk adapter

---
 pkg/pipeline/respback/respback.py             |  12 +-
 pkg/platform/adapter.py                       |  18 +-
 pkg/platform/sources/dingtalk.py              |  13 +-
 pkg/platform/sources/lark.py                  | 154 ++++++++----------
 pkg/platform/sources/telegram.py              |  13 +-
 pkg/provider/modelmgr/requesters/chatcmpl.py  |  47 ++----
 .../modelmgr/requesters/deepseekchatcmpl.py   |   2 +-
 pkg/provider/runners/dashscopeapi.py          |   4 -
 pkg/provider/runners/difysvapi.py             |   2 -
 pkg/provider/runners/localagent.py            |  48 +++---
 10 files changed, 127 insertions(+), 186 deletions(-)

diff --git a/pkg/pipeline/respback/respback.py b/pkg/pipeline/respback/respback.py
index 9a410b3f..f4153218 100644
--- a/pkg/pipeline/respback/respback.py
+++ b/pkg/pipeline/respback/respback.py
@@ -39,11 +39,9 @@ class SendResponseBackStage(stage.PipelineStage):
 
         quote_origin = query.pipeline_config['output']['misc']['quote-origin']
 
-        has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
-        print(has_chunks)
-        if has_chunks and hasattr(query.adapter,'reply_message_chunk'):
+        # has_chunks = any(isinstance(msg, llm_entities.MessageChunk) for msg in query.resp_messages)
+        if await query.adapter.is_stream_output_supported():
             is_final = [msg.is_final for msg in query.resp_messages][0]
-            print(is_final)
             await query.adapter.reply_message_chunk(
                 message_source=query.message_event,
                 message_id=query.resp_messages[-1].resp_message_id,
@@ -58,10 +56,6 @@ class SendResponseBackStage(stage.PipelineStage):
                 quote_origin=quote_origin,
             )
 
-        # await query.adapter.reply_message(
-        #     message_source=query.message_event,
-        #     message=query.resp_message_chain[-1],
-        #     quote_origin=quote_origin,
-        # )
+
         return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)
 
diff --git a/pkg/platform/adapter.py b/pkg/platform/adapter.py
index d4b48ef6..e4369efb 100644
--- a/pkg/platform/adapter.py
+++ b/pkg/platform/adapter.py
@@ -25,7 +25,6 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
 
     logger: EventLogger
 
-    is_stream: bool
 
     def __init__(self, config: dict, ap: app.Application, logger: EventLogger):
         """初始化适配器
@@ -62,26 +61,31 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta):
             quote_origin (bool, optional): 是否引用原消息. Defaults to False.
         """
         raise NotImplementedError
-
+
     async def reply_message_chunk(
         self,
-        message_source: platform_events.MessageEvent,
+        message_source: platform_events.MessageEvent,
         message_id: int,
         message: platform_message.MessageChain,
         quote_origin: bool = False,
         is_final: bool = False,
-    ):
+    ):
         """回复消息(流式输出)
 
         Args:
             message_source (platform.types.MessageEvent): 消息源事件
             message_id (int): 消息ID
             message (platform.types.MessageChain): 消息链
             quote_origin (bool, optional): 是否引用原消息. Defaults to False.
+            is_final (bool, optional): 流式是否结束. Defaults to False.
""" raise NotImplementedError - async def create_message_card(self,message_id,event): - '''创建卡片消息''' + async def create_message_card(self, message_id:typing.Type[str,int], event:platform_events.MessageEvent) -> bool: + """创建卡片消息 + Args: + message_id (str): 消息ID + event (platform_events.MessageEvent): 消息源事件 + """ return False async def is_muted(self, group_id: int) -> bool: @@ -117,11 +121,9 @@ class MessagePlatformAdapter(metaclass=abc.ABCMeta): async def run_async(self): """异步运行""" raise NotImplementedError - async def is_stream_output_supported(self) -> bool: """是否支持流式输出""" - self.is_stream = False return False async def kill(self) -> bool: diff --git a/pkg/platform/sources/dingtalk.py b/pkg/platform/sources/dingtalk.py index d1859ab5..9f834f2a 100644 --- a/pkg/platform/sources/dingtalk.py +++ b/pkg/platform/sources/dingtalk.py @@ -148,7 +148,7 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, is_final: bool = False, - ): + ): event = await DingTalkEventConverter.yiri2target( message_source, ) @@ -158,13 +158,12 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): content, at = await DingTalkMessageConverter.yiri2target(message) - card_instance,card_instance_id = self.card_instance_id_dict[message_id] + card_instance, card_instance_id = self.card_instance_id_dict[message_id] # print(card_instance_id) - await self.bot.send_card_message(card_instance,card_instance_id,content,is_final) + await self.bot.send_card_message(card_instance, card_instance_id, content, is_final) if is_final: self.card_instance_id_dict.pop(message_id) - async def send_message(self, target_type: str, target_id: str, message: platform_message.MessageChain): content = await DingTalkMessageConverter.yiri2target(message) if target_type == 'person': @@ -174,11 +173,11 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): async def is_stream_output_supported(self) -> bool: is_stream = False - if self.config.get("enable-stream-reply", None): + if self.config.get('enable-stream-reply', None): is_stream = True return is_stream - async def create_message_card(self,message_id,event): + async def create_message_card(self, message_id, event): card_template_id = self.config['card_template_id'] incoming_message = event.source_platform_object.incoming_message # message_id = incoming_message.message_id @@ -186,7 +185,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): self.card_instance_id_dict[message_id] = (card_instance, card_instance_id) return True - def register_listener( self, event_type: typing.Type[platform_events.Event], @@ -194,7 +192,6 @@ class DingTalkAdapter(adapter.MessagePlatformAdapter): ): async def on_message(event: DingTalkEvent): try: - await self.is_stream_output_supported() return await callback( await self.event_converter.target2yiri(event, self.config['robot_name']), self, diff --git a/pkg/platform/sources/lark.py b/pkg/platform/sources/lark.py index fb3d0c48..0d7fc0fb 100644 --- a/pkg/platform/sources/lark.py +++ b/pkg/platform/sources/lark.py @@ -9,7 +9,6 @@ import re import base64 import uuid import json -import time import datetime import hashlib from Crypto.Cipher import AES @@ -344,12 +343,11 @@ class LarkAdapter(adapter.MessagePlatformAdapter): config: dict quart_app: quart.Quart ap: app.Application - - message_id_to_card_id: typing.Dict[str, typing.Tuple[str, int]] - card_id_dict: dict[str, str] - seq: int + card_id_dict: dict[str, str] # 消息id到卡片id的映射,便于创建卡片后的发送消息到指定卡片 + + seq: int # 
用于在发送卡片消息中识别消息顺序,直接以seq作为标识 def __init__(self, config: dict, ap: app.Application, logger: EventLogger): self.config = config @@ -357,10 +355,9 @@ class LarkAdapter(adapter.MessagePlatformAdapter): self.logger = logger self.quart_app = quart.Quart(__name__) self.listeners = {} - self.message_id_to_card_id = {} self.card_id_dict = {} self.seq = 1 - self.card_id_time = {} + @self.quart_app.route('/lark/callback', methods=['POST']) async def lark_callback(): @@ -405,15 +402,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): await self.logger.error(f'Error in lark callback: {traceback.format_exc()}') return {'code': 500, 'message': 'error'} - - - - - - - async def on_message(event: lark_oapi.im.v1.P2ImMessageReceiveV1): - lb_event = await self.event_converter.target2yiri(event, self.api_client) await self.listeners[type(lb_event)](lb_event, self) @@ -435,51 +424,49 @@ class LarkAdapter(adapter.MessagePlatformAdapter): async def is_stream_output_supported(self) -> bool: is_stream = False - if self.config.get("enable-stream-reply", None): + if self.config.get('enable-stream-reply', None): is_stream = True return is_stream - async def create_card_id(self,message_id): + async def create_card_id(self, message_id): try: - is_stream = await self.is_stream_output_supported() - if is_stream: - self.ap.logger.debug('飞书支持stream输出,创建卡片......') + self.ap.logger.debug('飞书支持stream输出,创建卡片......') - card_data = {"schema": "2.0", "header": {"title": {"content": "bot", "tag": "plain_text"}}, - "body": {"elements": [ - {"tag": "markdown", "content": "[思考中.....]", "element_id": "markdown_1"}]}, - "config": {"streaming_mode": True, - "streaming_config": {"print_strategy": "delay"}}} # delay / fast + card_data = { + 'schema': '2.0', + 'header': {'title': {'content': 'bot', 'tag': 'plain_text'}}, + 'body': {'elements': [{'tag': 'markdown', 'content': '[思考中.....]', 'element_id': 'markdown_1'}]}, + 'config': {'streaming_mode': True, 'streaming_config': {'print_strategy': 'delay'}}, + } # delay / fast 创建卡片模板,delay 延迟打印,fast 实时打印,可以自定义更好看的消息模板 - request: CreateCardRequest = CreateCardRequest.builder() \ - .request_body( - CreateCardRequestBody.builder() - .type("card_json") - .data(json.dumps(card_data)) \ - .build() - ).build() + request: CreateCardRequest = ( + CreateCardRequest.builder() + .request_body(CreateCardRequestBody.builder().type('card_json').data(json.dumps(card_data)).build()) + .build() + ) - # 发起请求 - response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request) + # 发起请求 + response: CreateCardResponse = self.api_client.cardkit.v1.card.create(request) - # 处理失败返回 - if not response.success(): - raise Exception( - f"client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}") + # 处理失败返回 + if not response.success(): + raise Exception( + f'client.cardkit.v1.card.create failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' + ) - self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') - self.card_id_dict[message_id] = response.data.card_id + self.ap.logger.debug(f'飞书卡片创建成功,卡片ID: {response.data.card_id}') + self.card_id_dict[message_id] = response.data.card_id - card_id = response.data.card_id + card_id = response.data.card_id return card_id except Exception as e: self.ap.logger.error(f'飞书卡片创建失败,错误信息: {e}') - async def 
create_message_card(self,message_id,event) -> str: + async def create_message_card(self, message_id, event) -> str: """ 创建卡片消息。 - 使用卡片消息是因为普通消息更新次数有限制,而大模型流式返回结果可能很多而超过限制,而飞书卡片没有这个限制 + 使用卡片消息是因为普通消息更新次数有限制,而大模型流式返回结果可能很多而超过限制,而飞书卡片没有这个限制(api免费次数有限) """ # message_id = event.message_chain.message_id @@ -487,7 +474,7 @@ class LarkAdapter(adapter.MessagePlatformAdapter): content = { 'type': 'card', 'data': {'card_id': card_id, 'template_variable': {'content': 'Thinking...'}}, - } + } # 当收到消息时发送消息模板,可添加模板变量,详情查看飞书中接口文档 request: ReplyMessageRequest = ( ReplyMessageRequest.builder() .message_id(event.message_chain.message_id) @@ -545,7 +532,6 @@ class LarkAdapter(adapter.MessagePlatformAdapter): f'client.im.v1.message.reply failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' ) - async def reply_message_chunk( self, message_source: platform_events.MessageEvent, @@ -557,56 +543,50 @@ class LarkAdapter(adapter.MessagePlatformAdapter): """ 回复消息变成更新卡片消息 """ - lark_message = await self.message_converter.yiri2target(message, self.api_client) - - self.seq += 1 + if (self.seq - 1) % 8 == 0 or is_final: + lark_message = await self.message_converter.yiri2target(message, self.api_client) - text_message = '' - for ele in lark_message[0]: - if ele['tag'] == 'text': - text_message += ele['text'] - elif ele['tag'] == 'md': - text_message += ele['text'] - print(text_message) + text_message = '' + for ele in lark_message[0]: + if ele['tag'] == 'text': + text_message += ele['text'] + elif ele['tag'] == 'md': + text_message += ele['text'] - content = { - 'type': 'card_json', - 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, - } + # content = { + # 'type': 'card_json', + # 'data': {'card_id': self.card_id_dict[message_id], 'elements': {'content': text_message}}, + # } - request: ContentCardElementRequest = ContentCardElementRequest.builder() \ - .card_id(self.card_id_dict[message_id]) \ - .element_id("markdown_1") \ - .request_body(ContentCardElementRequestBody.builder() - # .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204") - .content(text_message) - .sequence(self.seq) - .build()) \ - .build() - - if is_final: - self.seq = 1 - self.card_id_dict.pop(message_id) - # 发起请求 - response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request) - - - # 处理失败返回 - if not response.success(): - raise Exception( - f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' + request: ContentCardElementRequest = ( + ContentCardElementRequest.builder() + .card_id(self.card_id_dict[message_id]) + .element_id('markdown_1') + .request_body( + ContentCardElementRequestBody.builder() + # .uuid("a0d69e20-1dd1-458b-k525-dfeca4015204") + .content(text_message) + .sequence(self.seq) + .build() + ) + .build() ) - return - - - - - + if is_final: + self.seq = 1 # 消息回复结束之后重置seq + self.card_id_dict.pop(message_id) # 清理已经使用过的卡片 + # 发起请求 + response: ContentCardElementResponse = self.api_client.cardkit.v1.card_element.content(request) + # 处理失败返回 + if not response.success(): + raise Exception( + f'client.im.v1.message.patch failed, code: {response.code}, msg: {response.msg}, log_id: {response.get_log_id()}, resp: \n{json.dumps(json.loads(response.raw.content), indent=4, ensure_ascii=False)}' + ) + return async def is_muted(self, 
group_id: int) -> bool: return False @@ -658,4 +638,4 @@ class LarkAdapter(adapter.MessagePlatformAdapter): # 所以要设置_auto_reconnect=False,让其不重连。 self.bot._auto_reconnect = False await self.bot._disconnect() - return False + return False \ No newline at end of file diff --git a/pkg/platform/sources/telegram.py b/pkg/platform/sources/telegram.py index efc7890f..22ef63e8 100644 --- a/pkg/platform/sources/telegram.py +++ b/pkg/platform/sources/telegram.py @@ -167,7 +167,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): await self.listeners[type(lb_event)](lb_event, self) await self.is_stream_output_supported() except Exception: - await self.logger.error(f"Error in telegram callback: {traceback.format_exc()}") + await self.logger.error(f'Error in telegram callback: {traceback.format_exc()}') self.application = ApplicationBuilder().token(self.config['token']).build() self.bot = self.application.bot @@ -206,7 +206,6 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): await self.bot.send_message(**args) - async def reply_message_chunk( self, message_source: platform_events.MessageEvent, @@ -214,8 +213,7 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): message: platform_message.MessageChain, quote_origin: bool = False, is_final: bool = False, - ): - + ): assert isinstance(message_source.source_platform_object, Update) components = await TelegramMessageConverter.yiri2target(message, self.bot) args = {} @@ -240,7 +238,6 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): if self.config['markdown_card'] is True: args['parse_mode'] = 'MarkdownV2' - send_msg = await self.bot.send_message(**args) send_msg_id = send_msg.message_id self.msg_stream_id[message_id] = send_msg_id @@ -264,16 +261,12 @@ class TelegramAdapter(adapter.MessagePlatformAdapter): if is_final: self.msg_stream_id.pop(message_id) - async def is_stream_output_supported(self) -> bool: is_stream = False - if self.config.get("enable-stream-reply", None): + if self.config.get('enable-stream-reply', None): is_stream = True - self.is_stream = is_stream - return is_stream - async def is_muted(self, group_id: int) -> bool: return False diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 8e350bf6..d5c3b90a 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -17,14 +17,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): """OpenAI ChatCompletion API 请求器""" client: openai.AsyncClient - is_content:bool + is_content: bool default_config: dict[str, typing.Any] = { 'base_url': 'https://api.openai.com/v1', 'timeout': 120, } - async def initialize(self): self.client = openai.AsyncClient( api_key='', @@ -46,7 +45,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): args: dict, extra_body: dict = {}, ) -> chat_completion.ChatCompletion: - async for chunk in await self.client.chat.completions.create(**args, extra_body=extra_body): yield chunk @@ -66,23 +64,23 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - pass else: - if reasoning_content is not None : - chatcmpl_message['content'] = '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + if reasoning_content is not None: + chatcmpl_message['content'] = ( + '\n' + reasoning_content + '\n\n' + chatcmpl_message['content'] + ) message = llm_entities.Message(**chatcmpl_message) return message - + async def _make_msg_chunk( 
self, pipeline_config: dict[str, typing.Any], chat_completion: chat_completion.ChatCompletion, idx: int, ) -> llm_entities.MessageChunk: - # 处理流式chunk和完整响应的差异 # print(chat_completion.choices[0]) if hasattr(chat_completion, 'choices'): @@ -98,7 +96,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): if 'role' not in delta or delta['role'] is None: delta['role'] = 'assistant' - reasoning_content = delta['reasoning_content'] if 'reasoning_content' in delta else None delta['content'] = '' if delta['content'] is None else delta['content'] @@ -106,13 +103,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): # deepseek的reasoner模型 if pipeline_config['trigger'].get('misc', '').get('remove_think'): - if reasoning_content is not None : + if reasoning_content is not None: pass else: delta['content'] = delta['content'] else: if reasoning_content is not None and idx == 0: - delta['content'] += f'\n{reasoning_content}' + delta['content'] += f'\n{reasoning_content}' elif reasoning_content is None: if self.is_content: delta['content'] = delta['content'] @@ -122,7 +119,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): else: delta['content'] += reasoning_content - message = llm_entities.MessageChunk(**delta) return message @@ -135,9 +131,10 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): use_funcs: list[tools_entities.LLMFunction] = None, stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + ) ->llm_entities.MessageChunk: self.client.api_key = use_model.token_mgr.get_token() + args = {} args['model'] = use_model.model_entity.name @@ -163,14 +160,14 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): if stream: current_content = '' - args["stream"] = True + args['stream'] = True chunk_idx = 0 self.is_content = False tool_calls_map: dict[str, llm_entities.ToolCall] = {} pipeline_config = query.pipeline_config async for chunk in self._req_stream(args, extra_body=extra_args): # 处理流式消息 - delta_message = await self._make_msg_chunk(pipeline_config,chunk,chunk_idx) + delta_message = await self._make_msg_chunk(pipeline_config, chunk, chunk_idx) if delta_message.content: current_content += delta_message.content delta_message.content = current_content @@ -182,15 +179,13 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): id=tool_call.id, type=tool_call.type, function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' + name=tool_call.function.name if tool_call.function else '', arguments='' ), ) if tool_call.function and tool_call.function.arguments: # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments - chunk_idx += 1 chunk_choices = getattr(chunk, 'choices', None) if chunk_choices and getattr(chunk_choices[0], 'finish_reason', None): @@ -198,11 +193,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): delta_message.content = current_content if chunk_idx % 64 == 0 or delta_message.is_final: - yield delta_message # return - async def _closure( self, query: core_entities.Query, @@ -211,7 +204,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): use_funcs: list[tools_entities.LLMFunction] = None, stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.Message | typing.AsyncGenerator[llm_entities.MessageChunk, None]: + ) -> llm_entities.Message: self.client.api_key = 
use_model.token_mgr.get_token() args = {} @@ -237,22 +230,15 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): args['messages'] = messages - - # 发送请求 resp = await self._req(args, extra_body=extra_args) # 处理请求结果 pipeline_config = query.pipeline_config - message = await self._make_msg(resp,pipeline_config) - + message = await self._make_msg(resp, pipeline_config) return message - - - - async def invoke_llm( self, query: core_entities.Query, @@ -273,7 +259,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): req_messages.append(msg_dict) try: - msg = await self._closure( query=query, req_messages=req_messages, @@ -334,7 +319,7 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): funcs: typing.List[tools_entities.LLMFunction] = None, stream: bool = False, extra_args: dict[str, typing.Any] = {}, - ) -> llm_entities.MessageChunk: + ) -> llm_entities.MessageChunk: req_messages = [] # req_messages 仅用于类内,外部同步由 query.messages 进行 for m in messages: msg_dict = m.dict(exclude_none=True) diff --git a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py index f57f624f..d75d0fb6 100644 --- a/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py +++ b/pkg/provider/modelmgr/requesters/deepseekchatcmpl.py @@ -55,6 +55,6 @@ class DeepseekChatCompletions(chatcmpl.OpenAIChatCompletions): raise errors.RequesterError('接口返回为空,请确定模型提供商服务是否正常') pipeline_config = query.pipeline_config # 处理请求结果 - message = await self._make_msg(resp,pipeline_config) + message = await self._make_msg(resp, pipeline_config) return message diff --git a/pkg/provider/runners/dashscopeapi.py b/pkg/provider/runners/dashscopeapi.py index fe72b0a8..9bb5824c 100644 --- a/pkg/provider/runners/dashscopeapi.py +++ b/pkg/provider/runners/dashscopeapi.py @@ -185,8 +185,6 @@ class DashScopeAPIRunner(runner.RequestRunner): # 将参考资料替换到文本中 pending_content = self._replace_references(pending_content, references_dict) - - yield llm_entities.Message( role='assistant', content=pending_content, @@ -261,13 +259,11 @@ class DashScopeAPIRunner(runner.RequestRunner): role='assistant', content=pending_content, is_final=is_final, - ) # 保存当前会话的session_id用于下次对话的语境 query.session.using_conversation.uuid = stream_output.get('session_id') - else: for chunk in response: if chunk.get('status_code') != 200: diff --git a/pkg/provider/runners/difysvapi.py b/pkg/provider/runners/difysvapi.py index 7c7d81ad..8182cc54 100644 --- a/pkg/provider/runners/difysvapi.py +++ b/pkg/provider/runners/difysvapi.py @@ -148,7 +148,6 @@ class DifyServiceAPIRunner(runner.RequestRunner): if mode == 'workflow': if chunk['event'] == 'node_finished': if not is_stream: - if chunk['data']['node_type'] == 'answer': yield llm_entities.Message( role='assistant', @@ -274,7 +273,6 @@ class DifyServiceAPIRunner(runner.RequestRunner): content=self._try_convert_thinking(pending_agent_message), ) - if chunk['event'] == 'agent_thought': if chunk['tool'] != '' and chunk['observation'] != '': # 工具调用结果,跳过 continue diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py index 6b4da90b..599b0b08 100644 --- a/pkg/provider/runners/localagent.py +++ b/pkg/provider/runners/localagent.py @@ -2,7 +2,6 @@ from __future__ import annotations import json import copy -from ssl import ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE import typing from .. 
import runner from ...core import entities as core_entities @@ -30,11 +29,14 @@ class LocalAgentRunner(runner.RequestRunner): class ToolCallTracker: """工具调用追踪器""" + def __init__(self): - self.active_calls: dict[str,dict] = {} + self.active_calls: dict[str, dict] = {} self.completed_calls: list[llm_entities.ToolCall] = [] - async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]: + async def run( + self, query: core_entities.Query + ) -> typing.AsyncGenerator[llm_entities.Message | llm_entities.MessageChunk, None]: """运行请求""" pending_tool_calls = [] @@ -89,16 +91,14 @@ class LocalAgentRunner(runner.RequestRunner): is_stream = query.adapter.is_stream_output_supported() try: - # print(await query.adapter.is_stream_output_supported()) is_stream = await query.adapter.is_stream_output_supported() except AttributeError: is_stream = False - # while True: - # pass + if not is_stream: # 非流式输出,直接请求 - # print(123) + msg = await query.use_llm_model.requester.invoke_llm( query, query.use_llm_model, @@ -108,7 +108,6 @@ class LocalAgentRunner(runner.RequestRunner): ) yield msg final_msg = msg - print(final_msg) else: # 流式输出,需要处理工具调用 tool_calls_map: dict[str, llm_entities.ToolCall] = {} @@ -122,27 +121,26 @@ class LocalAgentRunner(runner.RequestRunner): ): assert isinstance(msg, llm_entities.MessageChunk) yield msg - # if msg.tool_calls: - # for tool_call in msg.tool_calls: - # if tool_call.id not in tool_calls_map: - # tool_calls_map[tool_call.id] = llm_entities.ToolCall( - # id=tool_call.id, - # type=tool_call.type, - # function=llm_entities.FunctionCall( - # name=tool_call.function.name if tool_call.function else '', - # arguments='' - # ), - # ) - # if tool_call.function and tool_call.function.arguments: - # # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 - # tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments + if msg.tool_calls: + for tool_call in msg.tool_calls: + if tool_call.id not in tool_calls_map: + tool_calls_map[tool_call.id] = llm_entities.ToolCall( + id=tool_call.id, + type=tool_call.type, + function=llm_entities.FunctionCall( + name=tool_call.function.name if tool_call.function else '', + arguments='' + ), + ) + if tool_call.function and tool_call.function.arguments: + # 流式处理中,工具调用参数可能分多个chunk返回,需要追加而不是覆盖 + tool_calls_map[tool_call.id].function.arguments += tool_call.function.arguments final_msg = llm_entities.Message( role=msg.role, content=msg.all_content, tool_calls=list(tool_calls_map.values()), ) - pending_tool_calls = final_msg.tool_calls req_messages.append(final_msg) @@ -193,8 +191,7 @@ class LocalAgentRunner(runner.RequestRunner): id=tool_call.id, type=tool_call.type, function=llm_entities.FunctionCall( - name=tool_call.function.name if tool_call.function else '', - arguments='' + name=tool_call.function.name if tool_call.function else '', arguments='' ), ) if tool_call.function and tool_call.function.arguments: @@ -206,7 +203,6 @@ class LocalAgentRunner(runner.RequestRunner): tool_calls=list(tool_calls_map.values()), ) else: - print("非流式") # 处理完所有调用,再次请求 msg = await query.use_llm_model.requester.invoke_llm( query,