fix: dify non-stream remove_think logic

Dong_master
2025-08-14 22:32:22 +08:00
parent 13dd6fcee3
commit b8b9a37825


@@ -62,6 +62,39 @@ class DifyServiceAPIRunner(runner.RequestRunner):
content_text = re.sub(pattern, '', resp_text, flags=re.DOTALL)
return f'<think>{thinking_text.group(1)}</think>\n{content_text}'
def _process_thinking_content(
self,
content: str,
) -> tuple[str, str]:
"""处理思维链内容
Args:
content: 原始内容
Returns:
(处理后的内容, 提取的思维链内容)
"""
remove_think = self.pipeline_config['output'].get('misc', {}).get('remove-think')  # default to {} so the chained .get cannot fail on a missing key
thinking_content = ''
# Extract the <think> tag content from content
if content and '<think>' in content and '</think>' in content:
import re
think_pattern = r'<think>(.*?)</think>'
think_matches = re.findall(think_pattern, content, re.DOTALL)
if think_matches:
thinking_content = '\n'.join(think_matches)
# Strip the <think> tags out of content
content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
# Decide whether to keep the chain-of-thought based on remove_think
if remove_think:
return content, ''
else:
# If there is chain-of-thought content, prepend it to content wrapped in <think> tags
if thinking_content:
content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
return content, thinking_content
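For reference, here is a minimal standalone sketch of what the new helper does; the `pipeline_config` lookup is replaced with a plain `remove_think` flag, otherwise it mirrors the diff:

```python
import re

def process_thinking_content(content: str, remove_think: bool) -> tuple[str, str]:
    """Standalone sketch of _process_thinking_content with the flag passed in."""
    thinking_content = ''
    if content and '<think>' in content and '</think>' in content:
        think_pattern = r'<think>(.*?)</think>'
        matches = re.findall(think_pattern, content, re.DOTALL)
        if matches:
            thinking_content = '\n'.join(matches)
            content = re.sub(think_pattern, '', content, flags=re.DOTALL).strip()
    if remove_think:
        return content, ''  # drop the reasoning entirely
    if thinking_content:
        # re-attach the reasoning as a single leading <think> block
        content = f'<think>\n{thinking_content}\n</think>\n{content}'.strip()
    return content, thinking_content

# remove-think enabled: the answer comes back clean
assert process_thinking_content('<think>plan</think>hello', True) == ('hello', '')
# remove-think disabled: the reasoning is normalized into one leading block
assert process_thinking_content('<think>plan</think>hello', False) == (
    '<think>\nplan\n</think>\nhello',
    'plan',
)
```

Note that multiple `<think>` segments are joined with newlines, so repeated reasoning blocks collapse into one leading block when re-attached.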
async def _preprocess_user_message(self, query: core_entities.Query) -> tuple[str, list[str]]:
"""预处理用户消息,提取纯文本,并将图片上传到 Dify 服务
@@ -95,11 +128,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
cov_id = query.session.using_conversation.uuid or ''
query.variables['conversation_id'] = cov_id
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
@@ -113,11 +141,7 @@ class DifyServiceAPIRunner(runner.RequestRunner):
mode = 'basic'  # marks whether this is basic orchestration or workflow orchestration
stream_output_pending_chunk = ''
batch_pending_max_size = 8  # accumulate this many chunks before updating the message once
batch_pending_index = 0
basic_mode_pending_chunk = ''
inputs = {}
@@ -135,65 +159,28 @@ class DifyServiceAPIRunner(runner.RequestRunner):
):
self.ap.logger.debug('dify-chat-chunk: ' + str(chunk))
# query error cases
if chunk['event'] == 'error':
yield llm_entities.Message(
role='assistant',
content=f"查询异常: [{chunk['code']}]. {chunk['message']}.\n请重试,如果还报错,请用 <font color='red'>**!reset**</font> 命令重置对话再尝试。",
)
if chunk['event'] == 'workflow_started':
mode = 'workflow'
if mode == 'workflow':
if chunk['event'] == 'node_finished':
if not is_stream:
if chunk['data']['node_type'] == 'answer':
yield llm_entities.Message(
role='assistant',
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
)
else:
if chunk['data']['node_type'] == 'answer':
yield llm_entities.MessageChunk(
role='assistant',
content=self._try_convert_thinking(chunk['data']['outputs']['answer']),
is_final=True,
)
elif chunk['event'] == 'message':
stream_output_pending_chunk += chunk['answer']
if is_stream:
# flush once enough chunks accumulate, to produce a streaming effect
batch_pending_index += 1
if batch_pending_index >= batch_pending_max_size:
yield llm_entities.MessageChunk(
role='assistant',
content=self._try_convert_thinking(stream_output_pending_chunk),
)
batch_pending_index = 0
if chunk['data']['node_type'] == 'answer':
content, _ = self._process_thinking_content(chunk['data']['outputs']['answer'])
yield llm_entities.Message(
role='assistant',
content=content,
)
elif mode == 'basic':
if chunk['event'] == 'message' or chunk['event'] == 'message_end':
if chunk['event'] == 'message_end':
is_final = True
if is_stream and batch_pending_index % batch_pending_max_size == 0:
# flush once enough chunks accumulate, to produce a streaming effect
batch_pending_index += 1
# if batch_pending_index >= batch_pending_max_size:
yield llm_entities.MessageChunk(
role='assistant',
content=self._try_convert_thinking(stream_output_pending_chunk),
is_final=is_final,
)
# batch_pending_index = 0
elif not is_stream:
yield llm_entities.Message(
role='assistant',
content=self._try_convert_thinking(stream_output_pending_chunk),
)
stream_output_pending_chunk = ''
else:
stream_output_pending_chunk += chunk['answer']
is_final = False
if chunk['event'] == 'message':
basic_mode_pending_chunk += chunk['answer']
elif chunk['event'] == 'message_end':
content, _ = self._process_thinking_content(basic_mode_pending_chunk)
yield llm_entities.Message(
role='assistant',
content=content,
)
basic_mode_pending_chunk = ''
if chunk is None:
raise errors.DifyAPIError('The Dify API returned no response; check the network connection and API configuration')
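The substance of the non-stream fix in basic mode is that `message` chunks are only buffered, and the thinking filter runs exactly once on `message_end`. A hedged sketch of that flow, assuming Dify-style event dicts and reusing the `process_thinking_content` sketch above:

```python
def collect_basic_answers(chunks: list[dict], remove_think: bool) -> list[str]:
    """Buffer 'message' chunks and emit one filtered message per 'message_end'."""
    answers = []
    pending = ''
    for chunk in chunks:
        if chunk['event'] == 'message':
            pending += chunk['answer']  # accumulate partial answers
        elif chunk['event'] == 'message_end':
            content, _ = process_thinking_content(pending, remove_think)
            answers.append(content)  # one complete message per turn
            pending = ''
    return answers

events = [
    {'event': 'message', 'answer': '<think>rea'},
    {'event': 'message', 'answer': 'soning</think>final answer'},
    {'event': 'message_end'},
]
assert collect_basic_answers(events, remove_think=True) == ['final answer']
```

Filtering only at `message_end` matters because a `<think>` tag can be split across chunks, as in the example; the removed per-chunk `_try_convert_thinking` calls could see half a tag and pass it through unmodified.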
@@ -207,13 +194,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
cov_id = query.session.using_conversation.uuid or ''
query.variables['conversation_id'] = cov_id
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
batch_pending_index = 0
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
@@ -248,66 +228,39 @@ class DifyServiceAPIRunner(runner.RequestRunner):
if chunk['event'] in ignored_events:
continue
batch_pending_index += 1
if chunk['event'] == 'agent_message' or chunk['event'] == 'message_end':
if chunk['event'] == 'message_end':
# break
is_final = True
else:
is_final = False
pending_agent_message += chunk['answer']
if is_stream:
if batch_pending_index % 64 == 0 or is_final:
yield llm_entities.MessageChunk(
role='assistant',
content=self._try_convert_thinking(pending_agent_message),
is_final=is_final,
)
if chunk['event'] == 'agent_message' or chunk['event'] == 'message':
pending_agent_message += chunk['answer']
else:
if pending_agent_message.strip() != '' and not is_stream:
if pending_agent_message.strip() != '':
pending_agent_message = pending_agent_message.replace('</details>Action:', '</details>')
content, _ = self._process_thinking_content(pending_agent_message)
yield llm_entities.Message(
role='assistant',
content=self._try_convert_thinking(pending_agent_message),
content=content,
)
pending_agent_message = ''
if chunk['event'] == 'agent_thought':
if chunk['tool'] != '' and chunk['observation'] != '':  # tool call result, skip it
continue
if chunk['tool']:
if is_stream:
msg = llm_entities.MessageChunk(
role='assistant',
tool_calls=[
llm_entities.ToolCall(
id=chunk['id'],
type='function',
function=llm_entities.FunctionCall(
name=chunk['tool'],
arguments=json.dumps({}),
),
)
],
)
else:
msg = llm_entities.Message(
role='assistant',
tool_calls=[
llm_entities.ToolCall(
id=chunk['id'],
type='function',
function=llm_entities.FunctionCall(
name=chunk['tool'],
arguments=json.dumps({}),
),
)
],
)
msg = llm_entities.Message(
role='assistant',
tool_calls=[
llm_entities.ToolCall(
id=chunk['id'],
type='function',
function=llm_entities.FunctionCall(
name=chunk['tool'],
arguments=json.dumps({}),
),
)
],
)
yield msg
elif chunk['event'] == 'message_file':
if chunk['event'] == 'message_file':
if chunk['type'] == 'image' and chunk['belongs_to'] == 'assistant':
base_url = self.dify_client.base_url
@@ -315,20 +268,13 @@ class DifyServiceAPIRunner(runner.RequestRunner):
base_url = base_url[:-3]
image_url = base_url + chunk['url']
if is_stream:
yield llm_entities.MessageChunk(
role='assistant',
content=[llm_entities.ContentElement.from_image_url(image_url)],
)
else:
yield llm_entities.Message(
role='assistant',
content=[llm_entities.ContentElement.from_image_url(image_url)],
)
elif chunk['event'] == 'error':
yield llm_entities.Message(
role='assistant',
content=[llm_entities.ContentElement.from_image_url(image_url)],
)
if chunk['event'] == 'error':
raise errors.DifyAPIError('dify service error: ' + chunk['message'])
else:
pending_agent_message = ''
if chunk is None:
raise errors.DifyAPIError('The Dify API returned no response; check the network connection and API configuration')
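Agent mode follows the same pattern: `agent_message`/`message` chunks accumulate, and any other event flushes the buffer through the filter, after the `</details>Action:` cleanup shown above. A small sketch of that flush step (the function name is illustrative):

```python
def flush_agent_message(pending: str, remove_think: bool) -> str | None:
    """Sketch of the agent-mode flush: clean up, filter, and emit the buffer."""
    if pending.strip() == '':
        return None  # nothing buffered yet
    pending = pending.replace('</details>Action:', '</details>')
    content, _ = process_thinking_content(pending, remove_think)
    return content

assert flush_agent_message('<think>t</think>ok', remove_think=True) == 'ok'
assert flush_agent_message('   ', remove_think=True) is None
```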
@@ -343,15 +289,6 @@ class DifyServiceAPIRunner(runner.RequestRunner):
query.variables['conversation_id'] = query.session.using_conversation.uuid
try:
is_stream = await query.adapter.is_stream_output_supported()
except AttributeError:
is_stream = False
_ = is_stream
# batch_pending_index = 0
plain_text, image_ids = await self._preprocess_user_message(query)
files = [
@@ -408,10 +345,11 @@ class DifyServiceAPIRunner(runner.RequestRunner):
elif chunk['event'] == 'workflow_finished':
if chunk['data']['error']:
raise errors.DifyAPIError(chunk['data']['error'])
content, _ = self._process_thinking_content(chunk['data']['outputs']['summary'])
msg = llm_entities.Message(
role='assistant',
content=chunk['data']['outputs']['summary'],
content=content,
)
yield msg
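Workflow mode needs no buffering: the final text arrives in one piece on `workflow_finished`, and the same filter is applied to the summary output. A sketch under the same assumptions about the chunk shape:

```python
finished = {
    'event': 'workflow_finished',
    'data': {'error': None, 'outputs': {'summary': '<think>steps</think>done'}},
}
if finished['data']['error']:
    raise RuntimeError(finished['data']['error'])  # stand-in for errors.DifyAPIError
content, _ = process_thinking_content(finished['data']['outputs']['summary'], True)
assert content == 'done'
```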
@@ -430,4 +368,4 @@ class DifyServiceAPIRunner(runner.RequestRunner):
else:
raise errors.DifyAPIError(
f'Unsupported Dify app type: {self.pipeline_config["ai"]["dify-service-api"]["app-type"]}'
)
)