Files
LangBot/pkg/pipeline/process/handlers/chat.py

178 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import typing
import time
import traceback
import json
import mirai
from .. import handler
from ... import entities
from ....core import entities as core_entities
from ....provider import entities as llm_entities
from ....plugin import events
class ChatMessageHandler(handler.MessageHandler):
    """Message handler that answers a normal chat message through the LLM.

    Plugins get the first chance to handle the message; if none of them
    prevents the default behaviour, the query is sent to the model selected
    on the query and the responses are streamed back as stage results.
    """

    async def handle(
        self,
        query: core_entities.Query,
    ) -> typing.AsyncGenerator[entities.StageProcessResult, None]:
        """Process one chat query and yield stage results as they are produced.

        Args:
            query: The query being processed. Its ``resp_messages`` list is
                extended in place with every message that is produced.

        Yields:
            ``CONTINUE`` results for a plugin-supplied reply and for every
            message produced by :meth:`runner`; an ``INTERRUPT`` result when
            a plugin swallowed the message without replying, when chat is
            disabled in the provider config, or when the request failed.
        """
        # Emit the plugin event matching the launcher type (private vs. group).
        if query.launcher_type == core_entities.LauncherTypes.PERSON:
            event_class = events.PersonNormalMessageReceived
        else:
            event_class = events.GroupNormalMessageReceived

        event_ctx = await self.ap.plugin_mgr.emit_event(
            event=event_class(
                launcher_type=query.launcher_type.value,
                launcher_id=query.launcher_id,
                sender_id=query.sender_id,
                text_message=str(query.message_chain),
                query=query,
            )
        )

        if event_ctx.is_prevented_default():
            # A plugin took over: forward its reply if it set one, otherwise
            # stop processing this query.
            if event_ctx.event.reply is not None:
                mc = mirai.MessageChain(event_ctx.event.reply)

                query.resp_messages.append(
                    llm_entities.Message(
                        role='plugin',
                        content=mc,
                    )
                )

                yield entities.StageProcessResult(
                    result_type=entities.ResultType.CONTINUE,
                    new_query=query,
                )
            else:
                yield entities.StageProcessResult(
                    result_type=entities.ResultType.INTERRUPT,
                    new_query=query,
                )
            return

        if not self.ap.provider_cfg.data['enable-chat']:
            yield entities.StageProcessResult(
                result_type=entities.ResultType.INTERRUPT,
                new_query=query,
            )
            # Stop here even if the consumer keeps iterating: with chat
            # disabled the model must not be called and no usage recorded.
            return

        if event_ctx.event.alter is not None:
            # A plugin rewrote the message text.
            # NOTE(review): runner() builds its request from
            # query.user_message, not query.message_chain - confirm that the
            # rewritten chain actually reaches the model.
            query.message_chain = mirai.MessageChain([
                mirai.Plain(event_ctx.event.alter)
            ])

        text_length = 0
        start_time = time.time()

        try:
            async for result in self.runner(query):
                query.resp_messages.append(result)

                self.ap.logger.info(f'对话({query.query_id})响应: {self.cut_str(result.readable_str())}')

                if result.content is not None:
                    text_length += len(result.content)

                yield entities.StageProcessResult(
                    result_type=entities.ResultType.CONTINUE,
                    new_query=query,
                )

            # Only a fully successful exchange is written to the conversation
            # history.
            query.session.using_conversation.messages.append(query.user_message)
            query.session.using_conversation.messages.extend(query.resp_messages)
        except Exception as e:
            self.ap.logger.error(f'对话({query.query_id})请求失败: {str(e)}')

            yield entities.StageProcessResult(
                result_type=entities.ResultType.INTERRUPT,
                new_query=query,
                user_notice='请求失败' if self.ap.platform_cfg.data['hide-exception-info'] else f'{e}',
                error_notice=f'{e}',
                debug_notice=traceback.format_exc(),
            )
        finally:
            # The usage record is posted whether the request succeeded or not.
            await self.ap.ctr_mgr.usage.post_query_record(
                session_type=query.session.launcher_type.value,
                session_id=str(query.session.launcher_id),
                query_ability_provider="QChatGPT.Chat",
                usage=text_length,
                model_name=query.use_model.name,
                response_seconds=int(time.time() - start_time),
                retry_times=-1,
            )

    async def runner(
        self,
        query: core_entities.Query,
    ) -> typing.AsyncGenerator[llm_entities.Message, None]:
        """Run the LLM request / tool-call loop for a single query.

        The model is called, its reply is yielded, and as long as the reply
        carries tool calls each one is executed and the model is called again
        with the tool results appended to the request.

        This is a temporary approach; it may later be replaced by LangChain
        or an in-house workflow processor.

        Args:
            query: The query whose prompt, history, user message, model and
                functions are used to build the requests.

        Yields:
            Every model reply and every tool result message (role ``"tool"``),
            in the order they are produced.
        """
        await query.use_model.requester.preprocess(query)

        req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]

        while True:
            reply = await query.use_model.requester.call(query.use_model, req_messages, query.use_funcs)

            yield reply
            req_messages.append(reply)

            pending_tool_calls = reply.tool_calls
            if not pending_tool_calls:
                # No tool calls left to process: the exchange is complete.
                return

            for tool_call in pending_tool_calls:
                try:
                    func = tool_call.function

                    parameters = json.loads(func.arguments)

                    func_ret = await self.ap.tool_mgr.execute_func_call(
                        query, func.name, parameters
                    )

                    tool_msg = llm_entities.Message(
                        role="tool",
                        content=json.dumps(func_ret, ensure_ascii=False),
                        tool_call_id=tool_call.id,
                    )
                except Exception as e:
                    # The tool call failed: report the error to the model as
                    # the tool result instead of aborting the whole request.
                    tool_msg = llm_entities.Message(
                        role="tool",
                        content=f"err: {e}",
                        tool_call_id=tool_call.id,
                    )

                yield tool_msg
                req_messages.append(tool_msg)