Files
LangBot/pkg/pipeline/preproc/preproc.py
2025-07-13 18:41:04 +08:00

123 lines
4.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import datetime
from .. import stage, entities
from langbot_plugin.api.entities.builtin.provider import message as provider_message
import langbot_plugin.api.entities.events as events
import langbot_plugin.api.entities.builtin.platform.message as platform_message
import langbot_plugin.api.entities.builtin.pipeline.query as pipeline_query
@stage.stage_class('PreProcessor')
class PreProcessor(stage.PipelineStage):
    """Request pre-processing stage.

    Checks out the session and conversation for the incoming query and
    prepares everything later stages read from it.

    Fields written on the query:
        - session
        - prompt
        - messages
        - user_message
        - use_llm_model_uuid (only for the 'local-agent' runner)
        - use_funcs (only for the 'local-agent' runner)
        - variables (session_id, conversation_id, msg_create_time,
          user_message_text)
    """

    @staticmethod
    def _convert_element(element, accepts_images: bool):
        """Convert one platform message element into a provider content element.

        Returns None when the element has no provider-side representation here:
        an unsupported element type, an image while images are not accepted,
        or an image without base64 data.
        """
        if isinstance(element, platform_message.Plain):
            return provider_message.ContentElement.from_text(element.text)
        if isinstance(element, platform_message.Image):
            if accepts_images and element.base64 is not None:
                return provider_message.ContentElement.from_image_base64(element.base64)
        return None

    async def process(
        self,
        query: pipeline_query.Query,
        stage_inst_name: str,
    ) -> entities.StageProcessResult:
        """Populate the query with session, prompt, history and user message.

        Args:
            query: The query being processed; it is mutated in place.
            stage_inst_name: Name of this stage instance (unused here).

        Returns:
            A CONTINUE result carrying the updated query.
        """
        ai_config = query.pipeline_config['ai']
        selected_runner = ai_config['runner']['runner']
        is_local_agent = selected_runner == 'local-agent'

        session = await self.ap.sess_mgr.get_session(query)

        # llm_model stays None for every runner other than local-agent.
        llm_model = (
            await self.ap.model_mgr.get_model_by_uuid(ai_config['local-agent']['model'])
            if is_local_agent
            else None
        )

        conversation = await self.ap.sess_mgr.get_conversation(
            query,
            session,
            ai_config['local-agent']['prompt'],
            query.pipeline_uuid,
            query.bot_uuid,
        )

        # Populate the query.
        query.session = session
        query.prompt = conversation.prompt.copy()
        query.messages = conversation.messages.copy()

        # Images are kept unless a local-agent model lacks the vision ability.
        accepts_images = True

        if is_local_agent:
            abilities = llm_model.model_entity.abilities
            accepts_images = 'vision' in abilities

            # The model uuid only exists for local-agent; reading it for other
            # runners would dereference None.
            query.use_llm_model_uuid = llm_model.model_entity.uuid
            query.use_funcs = []
            if 'func_call' in abilities:
                query.use_funcs = await self.ap.tool_mgr.get_all_tools()

        session_name = f'{query.session.launcher_type.value}_{query.session.launcher_id}'

        variables = {
            'session_id': session_name,
            'conversation_id': conversation.uuid,
            'msg_create_time': (
                int(query.message_event.time) if query.message_event.time else int(datetime.datetime.now().timestamp())
            ),
        }
        query.variables.update(variables)

        # Check if this model supports vision, if not, remove all images
        # TODO this checking should be performed in runner, and in this stage, the image should be reserved
        if not accepts_images:
            for msg in query.messages:
                if isinstance(msg.content, list):
                    # Rebuild in place (slice assignment) instead of calling
                    # remove() while iterating, which skips the element
                    # following each removed one.
                    msg.content[:] = [me for me in msg.content if me.type != 'image_url']

        content_list = []
        plain_text = ''

        # 'misc' may be absent (or None); fall back to an empty mapping so the
        # chained lookup cannot fail.
        misc_config = query.pipeline_config['trigger'].get('misc') or {}
        combine_quote = misc_config.get('combine-quote-message')

        for me in query.message_chain:
            if isinstance(me, platform_message.Quote):
                if combine_quote:
                    # Fold the quoted message's text/images into the user message.
                    for quoted in me.origin:
                        converted = self._convert_element(quoted, accepts_images)
                        if converted is not None:
                            content_list.append(converted)
                continue

            converted = self._convert_element(me, accepts_images)
            if converted is not None:
                content_list.append(converted)
            # Only top-level plain text contributes to user_message_text.
            if isinstance(me, platform_message.Plain):
                plain_text += me.text

        query.variables['user_message_text'] = plain_text
        query.user_message = provider_message.Message(role='user', content=content_list)

        # =========== Emit the PromptPreProcessing event ===========
        event = events.PromptPreProcessing(
            session_name=session_name,
            default_prompt=query.prompt.messages,
            prompt=query.messages,
            query=query,
        )

        event_ctx = await self.ap.plugin_connector.emit_event(event)

        # Handlers may have replaced either list on the event; copy them back.
        query.prompt.messages = event_ctx.event.default_prompt
        query.messages = event_ctx.event.prompt

        return entities.StageProcessResult(result_type=entities.ResultType.CONTINUE, new_query=query)