feat: combine kb with pipeline

This commit is contained in:
Junyan Qin
2025-07-17 23:15:13 +08:00
parent 45afdbdfbb
commit 27bb4e1253
5 changed files with 93 additions and 31 deletions

View File

@@ -20,7 +20,6 @@ class LegacyPipeline(Base):
)
for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)

View File

@@ -0,0 +1,38 @@
from .. import migration
import sqlalchemy
from ...entity.persistence import pipeline as persistence_pipeline
@migration.migration_class(4)
class DBMigrateRAGKBUUID(migration.DBMigration):
    """Add the RAG knowledge-base UUID entry to every pipeline's config."""

    async def upgrade(self):
        """Upgrade: ensure each pipeline config has ``ai.local-agent.knowledge-base``.

        Reads every ``LegacyPipeline`` row, inserts an empty ``knowledge-base``
        entry where missing, and writes the config back together with the
        current application version.
        """
        # read all pipelines
        pipelines = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_pipeline.LegacyPipeline)
        )

        for pipeline in pipelines:
            serialized_pipeline = self.ap.persistence_mgr.serialize_model(
                persistence_pipeline.LegacyPipeline, pipeline
            )

            config = serialized_pipeline['config']

            # Guard against configs missing the 'ai' / 'local-agent' sections so a
            # single malformed row cannot abort the whole migration with a KeyError.
            local_agent_config = config.setdefault('ai', {}).setdefault('local-agent', {})
            if 'knowledge-base' not in local_agent_config:
                local_agent_config['knowledge-base'] = ''

            await self.ap.persistence_mgr.execute_async(
                sqlalchemy.update(persistence_pipeline.LegacyPipeline)
                .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
                .values(
                    {
                        'config': config,
                        'for_version': self.ap.ver_mgr.get_current_version(),
                    }
                )
            )

    async def downgrade(self):
        """Downgrade: no-op — the extra config key is harmless to keep."""
        pass

View File

@@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage):
if me.type == 'image_url':
msg.content.remove(me)
content_list = []
content_list: list[llm_entities.ContentElement] = []
plain_text = ''
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
# tidy the content_list
# combine all text content into one, and put it in the first position
for me in query.message_chain:
if isinstance(me, platform_message.Plain):
content_list.append(llm_entities.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
@@ -106,6 +107,8 @@ class PreProcessor(stage.PipelineStage):
if msg.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
query.variables['user_message_text'] = plain_text
query.user_message = llm_entities.Message(role='user', content=content_list)

View File

@@ -1,13 +1,28 @@
from __future__ import annotations
import json
import copy
import typing
from ...platform.types import message as platform_entities
from .. import runner
from ...core import entities as core_entities
from .. import entities as llm_entities
rag_combined_prompt_template = """
The following are relevant context entries retrieved from the knowledge base.
Please use them to answer the user's message.
Respond in the same language as the user's input.
<context>
{rag_context}
</context>
<user_message>
{user_message}
</user_message>
"""
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""
@@ -16,42 +31,49 @@ class LocalAgentRunner(runner.RequestRunner):
"""运行请求"""
pending_tool_calls = []
kb_uuid = query.pipeline_config['ai']['local-agent']['knowledge-base']
req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
user_message = copy.deepcopy(query.user_message)
user_message_text = ''
pipeline_uuid = query.pipeline_uuid
pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
if isinstance(user_message.content, str):
user_message_text = user_message.content
elif isinstance(user_message.content, list):
for ce in user_message.content:
if ce.type == 'text':
user_message_text += ce.text
break
try:
if pipeline and pipeline.pipeline_entity.knowledge_base_uuid is not None:
kb_id = pipeline.pipeline_entity.knowledge_base_uuid
kb= await self.ap.rag_mgr.load_knowledge_base(kb_id)
except Exception as e:
self.ap.logger.error(f'Failed to load knowledge base {kb_id}: {e}')
kb_id = None
if kb_uuid and user_message_text:
# only support text for now
kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
if kb:
message = ''
for msg in query.message_chain:
if isinstance(msg, platform_entities.Plain):
message += msg.text
result = await kb.retrieve(message)
if not kb:
self.ap.logger.warning(f'Knowledge base {kb_uuid} not found')
raise ValueError(f'Knowledge base {kb_uuid} not found')
result = await kb.retrieve(user_message_text)
final_user_message_text = ''
if result:
rag_context = "\n\n".join(
f"[{i+1}] {entry.metadata.get('text', '')}" for i, entry in enumerate(result)
rag_context = '\n\n'.join(
f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result)
)
rag_message = llm_entities.Message(
role="user",
content="The following are relevant context entries retrieved from the knowledge base. "
"Please use them to answer the user's question. "
"Respond in the same language as the user's input.\n\n" + rag_context
final_user_message_text = rag_combined_prompt_template.format(
rag_context=rag_context, user_message=user_message_text
)
req_messages += [rag_message]
else:
final_user_message_text = user_message_text
for ce in user_message.content:
if ce.type == 'text':
ce.text = final_user_message_text
break
req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message]
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(

View File

@@ -1,6 +1,6 @@
semantic_version = 'v4.0.8'
required_database_version = 3
required_database_version = 4
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
debug_mode = False