From 27bb4e1253149b31c4d7e74920aeb86fa86af254 Mon Sep 17 00:00:00 2001
From: Junyan Qin
Date: Thu, 17 Jul 2025 23:15:13 +0800
Subject: [PATCH] feat: combine kb with pipeline

---
 pkg/entity/persistence/pipeline.py               |  1 -
 .../migrations/dbm004_rag_kb_uuid.py             | 38 ++++++++++
 pkg/pipeline/preproc/preproc.py                  |  7 +-
 pkg/provider/runners/localagent.py               | 76 ++++++++++++-------
 pkg/utils/constants.py                           |  2 +-
 5 files changed, 93 insertions(+), 31 deletions(-)
 create mode 100644 pkg/persistence/migrations/dbm004_rag_kb_uuid.py

diff --git a/pkg/entity/persistence/pipeline.py b/pkg/entity/persistence/pipeline.py
index a07470f3..70e76dab 100644
--- a/pkg/entity/persistence/pipeline.py
+++ b/pkg/entity/persistence/pipeline.py
@@ -20,7 +20,6 @@ class LegacyPipeline(Base):
     )
     for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
     is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
-    knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
     stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
     config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
 
diff --git a/pkg/persistence/migrations/dbm004_rag_kb_uuid.py b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py
new file mode 100644
index 00000000..b45cfa78
--- /dev/null
+++ b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py
@@ -0,0 +1,38 @@
+from .. import migration
+
+import sqlalchemy
+
+from ...entity.persistence import pipeline as persistence_pipeline
+
+
+@migration.migration_class(4)
+class DBMigrateRAGKBUUID(migration.DBMigration):
+    """RAG knowledge base UUID"""
+
+    async def upgrade(self):
+        """Upgrade"""
+        # read all pipelines
+        pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
+
+        for pipeline in pipelines:
+            serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
+
+            config = serialized_pipeline['config']
+
+            if 'knowledge-base' not in config['ai']['local-agent']:
+                config['ai']['local-agent']['knowledge-base'] = ''
+
+            await self.ap.persistence_mgr.execute_async(
+                sqlalchemy.update(persistence_pipeline.LegacyPipeline)
+                .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
+                .values(
+                    {
+                        'config': config,
+                        'for_version': self.ap.ver_mgr.get_current_version(),
+                    }
+                )
+            )
+
+    async def downgrade(self):
+        """Downgrade"""
+        pass
diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py
index 19478200..fd4c0bb6 100644
--- a/pkg/pipeline/preproc/preproc.py
+++ b/pkg/pipeline/preproc/preproc.py
@@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage):
                     if me.type == 'image_url':
                         msg.content.remove(me)
 
-        content_list = []
+        content_list: list[llm_entities.ContentElement] = []
         plain_text = ''
         qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
 
+        # tidy the content_list:
+        # combine all text content into one element and put it in the first position
        for me in query.message_chain:
             if isinstance(me, platform_message.Plain):
-                content_list.append(llm_entities.ContentElement.from_text(me.text))
                 plain_text += me.text
             elif isinstance(me, platform_message.Image):
                 if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
@@ -106,6 +107,8 @@
                     if msg.base64 is not None:
                         content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
 
+        content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
+
         query.variables['user_message_text'] = plain_text
 
         query.user_message = llm_entities.Message(role='user', content=content_list)
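Review note: the preproc change above now accumulates every Plain segment into a single text element inserted at the front of content_list, instead of appending one element per segment. A standalone sketch of that assembly; the dataclasses below are simplified stand-ins for the project's platform_message and llm_entities types, not the real APIs:

```python
from dataclasses import dataclass


@dataclass
class Plain:
    text: str


@dataclass
class Image:
    base64: str | None


@dataclass
class ContentElement:
    type: str
    text: str | None = None
    image_base64: str | None = None


def assemble(message_chain: list) -> list[ContentElement]:
    """Mirror of the new content_list assembly: collect all text, then insert it first."""
    content_list: list[ContentElement] = []
    plain_text = ''

    for me in message_chain:
        if isinstance(me, Plain):
            # text segments are only accumulated here, not appended per segment
            plain_text += me.text
        elif isinstance(me, Image) and me.base64 is not None:
            content_list.append(ContentElement(type='image_url', image_base64=me.base64))

    # one combined text element, always at position 0, so a downstream runner
    # can rewrite the user's text in a single place
    content_list.insert(0, ContentElement(type='text', text=plain_text))
    return content_list


chain = [Plain('describe '), Image('aGk='), Plain('this image')]
for element in assemble(chain):
    print(element)
```

Keeping the combined text at index 0 is what lets the local-agent runner below replace the first text element with the RAG-augmented prompt without touching image elements.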
diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py
index 16d61e1b..e7fa12c4 100644
--- a/pkg/provider/runners/localagent.py
+++ b/pkg/provider/runners/localagent.py
@@ -1,13 +1,28 @@
 from __future__ import annotations
 
 import json
+import copy
 import typing
 
-from ...platform.types import message as platform_entities
 from .. import runner
 from ...core import entities as core_entities
 from .. import entities as llm_entities
 
+rag_combined_prompt_template = """
+The following are relevant context entries retrieved from the knowledge base.
+Please use them to answer the user's message.
+Respond in the same language as the user's input.
+
+<context>
+{rag_context}
+</context>
+
+<user_message>
+{user_message}
+</user_message>
+"""
+
 
 @runner.runner_class('local-agent')
 class LocalAgentRunner(runner.RequestRunner):
     """Local agent request runner"""
 
@@ -15,43 +30,50 @@
     async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
         """Run the request"""
         pending_tool_calls = []
 
-        req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
-
-        pipeline_uuid = query.pipeline_uuid
-        pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
+        kb_uuid = query.pipeline_config['ai']['local-agent']['knowledge-base']
+
+        user_message = copy.deepcopy(query.user_message)
 
-        try:
-            if pipeline and pipeline.pipeline_entity.knowledge_base_uuid is not None:
-                kb_id = pipeline.pipeline_entity.knowledge_base_uuid
-                kb = await self.ap.rag_mgr.load_knowledge_base(kb_id)
-        except Exception as e:
-            self.ap.logger.error(f'Failed to load knowledge base {kb_id}: {e}')
-            kb_id = None
+        user_message_text = ''
 
-        if kb:
-            message = ''
-            for msg in query.message_chain:
-                if isinstance(msg, platform_entities.Plain):
-                    message += msg.text
-            result = await kb.retrieve(message)
+        if isinstance(user_message.content, str):
+            user_message_text = user_message.content
+        elif isinstance(user_message.content, list):
+            for ce in user_message.content:
+                if ce.type == 'text':
+                    user_message_text += ce.text
+                    break
+
+        if kb_uuid and user_message_text:
+            # only text retrieval is supported for now
+            kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
+
+            if not kb:
+                self.ap.logger.warning(f'Knowledge base {kb_uuid} not found')
+                raise ValueError(f'Knowledge base {kb_uuid} not found')
+
+            result = await kb.retrieve(user_message_text)
+
+            final_user_message_text = ''
 
             if result:
-                rag_context = "\n\n".join(
-                    f"[{i+1}] {entry.metadata.get('text', '')}" for i, entry in enumerate(result)
+                rag_context = '\n\n'.join(
+                    f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result)
                 )
-                rag_message = llm_entities.Message(
-                    role="user",
-                    content="The following are relevant context entries retrieved from the knowledge base. "
-                    "Please use them to answer the user's question. "
-                    "Respond in the same language as the user's input.\n\n" + rag_context
+                final_user_message_text = rag_combined_prompt_template.format(
+                    rag_context=rag_context, user_message=user_message_text
                 )
-                req_messages += [rag_message]
+            else:
+                final_user_message_text = user_message_text
 
+            for ce in user_message.content:
+                if ce.type == 'text':
+                    ce.text = final_user_message_text
+                    break
+
+        req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message]
 
         # initial request
         msg = await query.use_llm_model.requester.invoke_llm(
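Review note: the net effect of the runner change is that retrieval results no longer ride along as an extra user message; they are folded into the deep-copied user message's first text element via rag_combined_prompt_template. A minimal sketch of that combination step, with plain dicts standing in for the retrieval entries' metadata:

```python
# the template mirrors rag_combined_prompt_template from the patch
rag_combined_prompt_template = """
The following are relevant context entries retrieved from the knowledge base.
Please use them to answer the user's message.
Respond in the same language as the user's input.

<context>
{rag_context}
</context>

<user_message>
{user_message}
</user_message>
"""


def combine(entries: list[dict], user_message_text: str) -> str:
    """Fold retrieved entries into the user's text, as the runner does."""
    if not entries:
        # no hits: the user's text goes through unchanged
        return user_message_text
    # number each entry so the model can refer back to it
    rag_context = '\n\n'.join(
        f'[{i + 1}] {entry.get("text", "")}' for i, entry in enumerate(entries)
    )
    return rag_combined_prompt_template.format(
        rag_context=rag_context, user_message=user_message_text
    )


print(combine([{'text': 'Pipelines bind a bot to one message flow.'}], 'What is a pipeline?'))
```

Because query.user_message is deep-copied before the rewrite, the original message stored on the query stays intact; only the copy sent to the model carries the injected context.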
" - "Respond in the same language as the user's input.\n\n" + rag_context + final_user_message_text = rag_combined_prompt_template.format( + rag_context=rag_context, user_message=user_message_text ) - req_messages += [rag_message] + else: + final_user_message_text = user_message_text + for ce in user_message.content: + if ce.type == 'text': + ce.text = final_user_message_text + break + req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message] # 首次请求 msg = await query.use_llm_model.requester.invoke_llm( diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py index e8193839..711ebf5d 100644 --- a/pkg/utils/constants.py +++ b/pkg/utils/constants.py @@ -1,6 +1,6 @@ semantic_version = 'v4.0.8' -required_database_version = 3 +required_database_version = 4 """Tag the version of the database schema, used to check if the database needs to be migrated""" debug_mode = False