Mirror of https://github.com/langbot-app/LangBot.git, synced 2025-11-25 11:29:39 +08:00
feat: combine kb with pipeline
@@ -20,7 +20,6 @@ class LegacyPipeline(Base):
     )
     for_version = sqlalchemy.Column(sqlalchemy.String(255), nullable=False)
     is_default = sqlalchemy.Column(sqlalchemy.Boolean, nullable=False, default=False)
-    knowledge_base_uuid = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
     stages = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
     config = sqlalchemy.Column(sqlalchemy.JSON, nullable=False)
 
pkg/persistence/migrations/dbm004_rag_kb_uuid.py (new file, 38 lines)
@@ -0,0 +1,38 @@
+from .. import migration
+
+import sqlalchemy
+
+from ...entity.persistence import pipeline as persistence_pipeline
+
+
+@migration.migration_class(4)
+class DBMigrateRAGKBUUID(migration.DBMigration):
+    """RAG knowledge base UUID"""
+
+    async def upgrade(self):
+        """Upgrade"""
+        # read all pipelines
+        pipelines = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_pipeline.LegacyPipeline))
+
+        for pipeline in pipelines:
+            serialized_pipeline = self.ap.persistence_mgr.serialize_model(persistence_pipeline.LegacyPipeline, pipeline)
+
+            config = serialized_pipeline['config']
+
+            if 'knowledge-base' not in config['ai']['local-agent']:
+                config['ai']['local-agent']['knowledge-base'] = ''
+
+            await self.ap.persistence_mgr.execute_async(
+                sqlalchemy.update(persistence_pipeline.LegacyPipeline)
+                .where(persistence_pipeline.LegacyPipeline.uuid == serialized_pipeline['uuid'])
+                .values(
+                    {
+                        'config': config,
+                        'for_version': self.ap.ver_mgr.get_current_version(),
+                    }
+                )
+            )
+
+    async def downgrade(self):
+        """Downgrade"""
+        pass
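The upgrade is a config backfill: every stored pipeline gets an 'ai' -> 'local-agent' -> 'knowledge-base' key, defaulting to the empty string, replacing the dropped knowledge_base_uuid column. A minimal standalone sketch of the same transformation on a plain dict (the helper name and sample config here are illustrative, not LangBot API):

# Sketch of the config backfill done by DBMigrateRAGKBUUID.upgrade,
# applied to a plain dict instead of a serialized SQLAlchemy row.
def backfill_kb_key(config: dict) -> dict:
    local_agent = config['ai']['local-agent']
    # Same guard as the migration: only add the key when it is missing,
    # so re-running the migration is harmless.
    if 'knowledge-base' not in local_agent:
        local_agent['knowledge-base'] = ''
    return config

config = {'ai': {'local-agent': {'model': 'some-model-uuid'}}}  # hypothetical config
assert backfill_kb_key(config)['ai']['local-agent']['knowledge-base'] == ''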
@@ -80,14 +80,15 @@ class PreProcessor(stage.PipelineStage):
                     if me.type == 'image_url':
                         msg.content.remove(me)
 
-        content_list = []
+        content_list: list[llm_entities.ContentElement] = []
 
+        plain_text = ''
         qoute_msg = query.pipeline_config['trigger'].get('misc', '').get('combine-quote-message')
 
         # tidy the content_list
         # combine all text content into one, and put it in the first position
         for me in query.message_chain:
             if isinstance(me, platform_message.Plain):
                 content_list.append(llm_entities.ContentElement.from_text(me.text))
+                plain_text += me.text
             elif isinstance(me, platform_message.Image):
                 if selected_runner != 'local-agent' or query.use_llm_model.model_entity.abilities.__contains__(
@@ -106,6 +107,8 @@ class PreProcessor(stage.PipelineStage):
                     if msg.base64 is not None:
                         content_list.append(llm_entities.ContentElement.from_image_base64(msg.base64))
 
+        content_list.insert(0, llm_entities.ContentElement.from_text(plain_text))
 
+        query.variables['user_message_text'] = plain_text
 
         query.user_message = llm_entities.Message(role='user', content=content_list)
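Net effect of the two PreProcessor hunks: every Plain segment is accumulated into plain_text, one combined text element is inserted at index 0 ahead of the image elements, and the combined text is exposed to later stages via query.variables['user_message_text']. A rough sketch of that ordering, with simplified (type, value) tuples standing in for llm_entities.ContentElement (the message chain contents are made up):

# Simplified model of the content_list assembly in PreProcessor.
chain = [('text', 'describe '), ('image', 'img-base64'), ('text', 'this picture')]

content_list = []
plain_text = ''
for kind, value in chain:
    content_list.append((kind, value))
    if kind == 'text':
        plain_text += value

# as in the committed code: the combined text goes to the front
content_list.insert(0, ('text', plain_text))

print(content_list[0])  # ('text', 'describe this picture')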
@@ -1,13 +1,28 @@
 from __future__ import annotations
 
 import json
+import copy
 import typing
-from ...platform.types import message as platform_entities
 from .. import runner
 from ...core import entities as core_entities
 from .. import entities as llm_entities
 
 
+rag_combined_prompt_template = """
+The following are relevant context entries retrieved from the knowledge base.
+Please use them to answer the user's message.
+Respond in the same language as the user's input.
+
+<context>
+{rag_context}
+</context>
+
+<user_message>
+{user_message}
+</user_message>
+"""
+
+
 @runner.runner_class('local-agent')
 class LocalAgentRunner(runner.RequestRunner):
     """Local agent request runner"""
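rag_combined_prompt_template is an ordinary str.format template, so rendering it looks like this (sample values, using the template defined above):

# Example rendering: the retrieved entries land inside <context> and the
# user's original text inside <user_message>.
prompt = rag_combined_prompt_template.format(
    rag_context='[1] LangBot supports plugins.\n\n[2] Pipelines are configurable.',
    user_message='How do I configure a pipeline?',
)
print(prompt)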
@@ -16,42 +31,49 @@ class LocalAgentRunner(runner.RequestRunner):
         """Run the request"""
         pending_tool_calls = []
 
-        req_messages = query.prompt.messages.copy() + query.messages.copy() + [query.user_message]
+        kb_uuid = query.pipeline_config['ai']['local-agent']['knowledge-base']
+
+        user_message = copy.deepcopy(query.user_message)
+
+        user_message_text = ''
 
-        pipeline_uuid = query.pipeline_uuid
-        pipeline = await self.ap.pipeline_mgr.get_pipeline_by_uuid(pipeline_uuid)
+        if isinstance(user_message.content, str):
+            user_message_text = user_message.content
+        elif isinstance(user_message.content, list):
+            for ce in user_message.content:
+                if ce.type == 'text':
+                    user_message_text += ce.text
+                    break
 
-        try:
-            if pipeline and pipeline.pipeline_entity.knowledge_base_uuid is not None:
-                kb_id = pipeline.pipeline_entity.knowledge_base_uuid
-                kb = await self.ap.rag_mgr.load_knowledge_base(kb_id)
-        except Exception as e:
-            self.ap.logger.error(f'Failed to load knowledge base {kb_id}: {e}')
-            kb_id = None
-
-        if kb:
-            message = ''
-            for msg in query.message_chain:
-                if isinstance(msg, platform_entities.Plain):
-                    message += msg.text
-            result = await kb.retrieve(message)
+        if kb_uuid and user_message_text:
+            # only support text for now
+            kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
+
+            if not kb:
+                self.ap.logger.warning(f'Knowledge base {kb_uuid} not found')
+                raise ValueError(f'Knowledge base {kb_uuid} not found')
+
+            result = await kb.retrieve(user_message_text)
+
+            final_user_message_text = ''
 
             if result:
-                rag_context = "\n\n".join(
-                    f"[{i+1}] {entry.metadata.get('text', '')}" for i, entry in enumerate(result)
+                rag_context = '\n\n'.join(
+                    f'[{i + 1}] {entry.metadata.get("text", "")}' for i, entry in enumerate(result)
                 )
-                rag_message = llm_entities.Message(
-                    role="user",
-                    content="The following are relevant context entries retrieved from the knowledge base. "
-                    "Please use them to answer the user's question. "
-                    "Respond in the same language as the user's input.\n\n" + rag_context
+                final_user_message_text = rag_combined_prompt_template.format(
+                    rag_context=rag_context, user_message=user_message_text
                 )
-                req_messages += [rag_message]
+            else:
+                final_user_message_text = user_message_text
+
+            for ce in user_message.content:
+                if ce.type == 'text':
+                    ce.text = final_user_message_text
+                    break
+
+        req_messages = query.prompt.messages.copy() + query.messages.copy() + [user_message]
 
         # first request
         msg = await query.use_llm_model.requester.invoke_llm(
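The copy.deepcopy(query.user_message) is what lets run() swap the RAG-wrapped prompt into the outgoing message without touching the original query. A sketch of that in-place rewrite, with an assumed minimal stand-in for llm_entities.ContentElement:

# Why the runner deep-copies query.user_message: the wrapped prompt is
# written back into the text element in place, and mutating the original
# would also change what other stages and stored history see.
import copy
from dataclasses import dataclass

@dataclass
class ContentElement:  # assumed stand-in, not the real llm_entities class
    type: str
    text: str = ''

original = [ContentElement('text', 'What is LangBot?')]
working = copy.deepcopy(original)

for ce in working:
    if ce.type == 'text':
        ce.text = '<context>...</context> wrapped prompt'
        break

print(original[0].text)  # unchanged: 'What is LangBot?'
print(working[0].text)   # the wrapped prompt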
@@ -1,6 +1,6 @@
 semantic_version = 'v4.0.8'
 
-required_database_version = 3
+required_database_version = 4
 """Tag the version of the database schema, used to check if the database needs to be migrated"""
 
 debug_mode = False
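required_database_version is the schema version this code expects; bumping it to 4 is what makes existing installs run the new dbm004 migration. A hedged sketch of that gate (the actual manager API in LangBot may differ):

# Assumed shape of the migration gate: apply every registered migration
# numbered above the version stored in the database, up to the required one.
async def migrate_if_needed(stored_version: int, migrations: dict[int, 'DBMigration']):
    for number in range(stored_version + 1, required_database_version + 1):
        if number in migrations:
            await migrations[number].upgrade()  # e.g. DBMigrateRAGKBUUID at 4
    return required_database_version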