Merge pull request #785 from RockChinQ/fix/msg-chain-compability

Fix: 修复 query.resp_messages 对插件reply的兼容性
This commit is contained in:
Junyan Qin
2024-05-18 20:13:50 +08:00
committed by GitHub
6 changed files with 89 additions and 79 deletions

View File

@@ -67,7 +67,7 @@ class Query(pydantic.BaseModel):
use_funcs: typing.Optional[list[tools_entities.LLMFunction]] = None
"""使用的函数,由前置处理器阶段设置"""
resp_messages: typing.Optional[list[llm_entities.Message]] = []
resp_messages: typing.Optional[list[llm_entities.Message]] | typing.Optional[list[mirai.MessageChain]] = []
"""由Process阶段生成的回复消息对象列表"""
resp_message_chain: typing.Optional[list[mirai.MessageChain]] = None

View File

@@ -163,13 +163,13 @@ class ContentFilterStage(stage.PipelineStage):
)
elif stage_inst_name == 'PostContentFilterStage':
# 仅处理 query.resp_messages[-1].content 是 str 的情况
if isinstance(query.resp_messages[-1].content, str):
if isinstance(query.resp_messages[-1], llm_entities.Message) and isinstance(query.resp_messages[-1].content, str):
return await self._post_process(
query.resp_messages[-1].content,
query
)
else:
self.ap.logger.debug(f"resp_messages[-1] 不是 str 类型,跳过内容过滤器检查。")
self.ap.logger.debug(f"resp_messages[-1] 不是 Message 类型或 query.resp_messages[-1].content 不是 str 类型,跳过内容过滤器检查。")
return entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query

View File

@@ -42,12 +42,7 @@ class ChatMessageHandler(handler.MessageHandler):
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='plugin',
content=mc,
)
)
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@@ -48,12 +48,7 @@ class CommandHandler(handler.MessageHandler):
if event_ctx.event.reply is not None:
mc = mirai.MessageChain(event_ctx.event.reply)
query.resp_messages.append(
llm_entities.Message(
role='command',
content=str(mc),
)
)
query.resp_messages.append(mc)
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,

View File

@@ -33,6 +33,17 @@ class ResponseWrapper(stage.PipelineStage):
"""处理
"""
# 如果 resp_messages[-1] 已经是 MessageChain 了
if isinstance(query.resp_messages[-1], mirai.MessageChain):
query.resp_message_chain.append(query.resp_messages[-1])
yield entities.StageProcessResult(
result_type=entities.ResultType.CONTINUE,
new_query=query
)
else:
if query.resp_messages[-1].role == 'command':
# query.resp_message_chain.append(mirai.MessageChain("[bot] "+query.resp_messages[-1].content))
query.resp_message_chain.append(query.resp_messages[-1].get_content_mirai_message_chain(prefix_text='[bot] '))

View File

@@ -96,7 +96,16 @@ class Message(pydantic.BaseModel):
if ce.type == 'text':
mc.append(mirai.Plain(ce.text))
elif ce.type == 'image':
mc.append(mirai.Image(url=ce.image_url))
if ce.image_url.url.startswith("http"):
mc.append(mirai.Image(url=ce.image_url.url))
else: # base64
b64_str = ce.image_url.url
if b64_str.startswith("data:"):
b64_str = b64_str.split(",")[1]
mc.append(mirai.Image(base64=b64_str))
# 找第一个文字组件
if prefix_text: