diff --git a/pkg/entity/persistence/pipeline.py b/pkg/entity/persistence/pipeline.py
index a07470f3..70e76dab 100644
--- a/pkg/entity/persistence/pipeline.py
+++ b/pkg/entity/persistence/pipeline.py
@@ -20,7 +20,6 @@ class LegacyPipeline(Base):
)
for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
- knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
diff --git a/pkg/persistence/migrations/dbm004_rag_kb_uuid.py b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py
new file mode 100644
index 00000000..b45cfa78
--- /dev/null
+++ b/pkg/persistence/migrations/dbm004_rag_kb_uuid.py
@@ -0,0 +1,38 @@
+from .. import migration
+
+import sqlalchemy
+
+from ...entity.persistence import pipeline as persistence_pipeline
+
+
+@migration.migration_class(4)
+class DBMigrateRAGKBUUID(migration.DBMigration):
+ """RAG知识库UUID"""
+
+ async def upgrade(self):
+ """升级"""
+ # read all pipelines
+ pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
+
+ for pipeline in pipelines:
+ serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
+
+ config = serialized_pipeline['config']
+
+            if 'knowledge-base' not in config.get('ai', {}).get('local-agent', {}):
+                config.setdefault('ai', {}).setdefault('local-agent', {})['knowledge-base'] = ''
+
+ await self.ap.persistence_mgr.execute_async(
+ sqlalchemy.update(persistence_pipeline.LegacyPipeline)
+ .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
+ .values(
+ {
+ 'config': config,
+ 'for_version': self.ap.ver_mgr.get_current_version(),
+ }
+ )
+ )
+
+ async def downgrade(self):
+ """降级"""
+ pass
diff --git a/pkg/pipeline/preproc/preproc.py b/pkg/pipeline/preproc/preproc.py
index 19478200..fd4c0bb6 100644
--- a/pkg/pipeline/preproc/preproc.py
+++ b/pkg/pipeline/preproc/preproc.py
@@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage):
if me.type == 'image_url':
msg.content.remove(me)
- content_list = []
+ content_list: list[llm_entities.ContentElement] = []
plain_text = ''
qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
+ # tidy the content_list
+ # combine all text content into one, and put it in the first position
for me in query.message_chain:
if isinstance(me, platform_message.Plain):
- content_list.append(llm_entities.ContentElement.from_text(me.text))
plain_text += me.text
elif isinstance(me, platform_message.Image):
if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
@@ -106,6 +107,8 @@ class PreProcessor(stage.PipelineStage):
if msg.base64 is not None:
content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
+ content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
+
query.variables['user_message_text'] = plain_text
query.user_message = llm_entities.Message(role='user', content=content_list)
diff --git a/pkg/provider/runners/localagent.py b/pkg/provider/runners/localagent.py
index 16d61e1b..e7fa12c4 100644
--- a/pkg/provider/runners/localagent.py
+++ b/pkg/provider/runners/localagent.py
@@ -1,13 +1,28 @@
from __future__ import annotations
import json
+import copy
import typing
-from ...platform.types import message as platform_entities
from .. import runner
from ...core import entities as core_entities
from .. import entities as llm_entities
+rag_combined_prompt_template = """
+The following are relevant context entries retrieved from the knowledge base.
+Please use them to answer the user's message.
+Respond in the same language as the user's input.
+
+
+{rag_context}
+
+
+
+{user_message}
+
+"""
+
+
@runner.runner_class('local-agent')
class LocalAgentRunner(runner.RequestRunner):
"""本地Agent请求运行器"""
@@ -15,43 +30,50 @@ class LocalAgentRunner(runner.RequestRunner):
async def run(self, query: core_entities.Query) -> typing.AsyncGenerator[llm_entities.Message, None]:
"""运行请求"""
pending_tool_calls = []
-
- req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
+        kb_uuid = query.pipeline_config['ai']['local-agent'].get('knowledge-base', '')
-
- pipeline_uuid = query.pipeline_uuid
- pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
+ user_message = copy.deepcopy(query.user_message)
- try:
- if pipeline and pipeline.pipeline_entity.knowledge_base_uuid is not None:
- kb_id = pipeline.pipeline_entity.knowledge_base_uuid
- kb= await self.ap.rag_mgr.load_knowledge_base(kb_id)
- except Exception as e:
- self.ap.logger.error(f'Failed to load knowledge base {kb_id}: {e}')
- kb_id = None
+ user_message_text = ''
- if kb:
- message = ''
- for msg in query.message_chain:
- if isinstance(msg, platform_entities.Plain):
- message += msg.text
- result = await kb.retrieve(message)
+ if isinstance(user_message.content, str):
+ user_message_text = user_message.content
+ elif isinstance(user_message.content, list):
+ for ce in user_message.content:
+ if ce.type == 'text':
+ user_message_text += ce.text
+ break
+
+ if kb_uuid and user_message_text:
+ # only support text for now
+ kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
+
+ if not kb:
+ self.ap.logger.warning(f'Knowledge base {kb_uuid} not found')
+ raise ValueError(f'Knowledge base {kb_uuid} not found')
+
+ result = await kb.retrieve(user_message_text)
+
+ final_user_message_text = ''
if result:
- rag_context = "\n\n".join(
- f"[{i+1}] {entry.metadata.get('text', '')}" for i, entry in enumerate(result)
+ rag_context = '\n\n'.join(
+ f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result)
)
- rag_message = llm_entities.Message(
- role="user",
- content="The following are relevant context entries retrieved from the knowledge base. "
- "Please use them to answer the user's question. "
- "Respond in the same language as the user's input.\n\n" + rag_context
+ final_user_message_text = rag_combined_prompt_template.format(
+ rag_context=rag_context, user_message=user_message_text
)
- req_messages += [rag_message]
+ else:
+ final_user_message_text = user_message_text
+            if isinstance(user_message.content, str):
+                user_message.content = final_user_message_text
+            else:  # replace the combined text element in place; it exists because user_message_text is non-empty
+                next(ce for ce in user_message.content if ce.type == 'text').text = final_user_message_text
+ req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message]
# 首次请求
msg = await query.use_llm_model.requester.invoke_llm(
diff --git a/pkg/utils/constants.py b/pkg/utils/constants.py
index e8193839..711ebf5d 100644
--- a/pkg/utils/constants.py
+++ b/pkg/utils/constants.py
@@ -1,6 +1,6 @@
semantic_version = 'v4.0.8'
-required_database_version = 3
+required_database_version = 4
"""Tag the version of the database schema, used to check if the database needs to be migrated"""
debug_mode = False