mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
feat(rag): make embedding and retrieving available
This commit is contained in:
@@ -86,10 +86,10 @@ class RouterGroup(abc.ABC):
|
||||
|
||||
try:
|
||||
return await f(*args, **kwargs)
|
||||
except Exception: # 自动 500
|
||||
except Exception as e: # 自动 500
|
||||
traceback.print_exc()
|
||||
# return self.http_status(500, -2, str(e))
|
||||
return self.http_status(500, -2, 'internal server error')
|
||||
return self.http_status(500, -2, str(e))
|
||||
|
||||
new_f = handler_error
|
||||
new_f.__name__ = (self.name + rule).replace('/', '__')
|
||||
|
||||
@@ -34,8 +34,9 @@ class FilesRouterGroup(group.RouterGroup):
|
||||
|
||||
file_bytes = await asyncio.to_thread(file.stream.read)
|
||||
extension = file.filename.split('.')[-1]
|
||||
file_name = file.filename.split('.')[0]
|
||||
|
||||
file_key = str(uuid.uuid4()) + '.' + extension
|
||||
file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
|
||||
# save file to storage
|
||||
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
|
||||
return self.success(
|
||||
|
||||
@@ -72,3 +72,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str:
|
||||
await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
|
||||
return self.success({})
|
||||
|
||||
@self.route(
|
||||
'/<knowledge_base_uuid>/retrieve',
|
||||
methods=['POST'],
|
||||
)
|
||||
async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
|
||||
json_data = await quart.request.json
|
||||
query = json_data.get('query')
|
||||
results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query)
|
||||
return self.success(data={'results': results})
|
||||
|
||||
@@ -67,6 +67,13 @@ class KnowledgeService:
|
||||
raise Exception('Knowledge base not found')
|
||||
return await runtime_kb.store_file(file_id)
|
||||
|
||||
async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
|
||||
"""检索知识库"""
|
||||
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
|
||||
if runtime_kb is None:
|
||||
raise Exception('Knowledge base not found')
|
||||
return [result.model_dump() for result in await runtime_kb.retrieve(query)]
|
||||
|
||||
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
|
||||
"""获取知识库文件"""
|
||||
result = await self.ap.persistence_mgr.execute_async(
|
||||
|
||||
@@ -186,6 +186,6 @@ class EmbeddingModelsService:
|
||||
|
||||
await runtime_embedding_model.requester.invoke_embedding(
|
||||
model=runtime_embedding_model,
|
||||
input_text='Hello, world!',
|
||||
input_text=['Hello, world!'],
|
||||
extra_args={},
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
from ..rag.knowledge import mgr as rag_mgr
|
||||
from ..rag.knowledge import kbmgr as rag_mgr
|
||||
from ..vector import mgr as vectordb_mgr
|
||||
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...rag.knowledge import mgr as rag_mgr
|
||||
from ...rag.knowledge import kbmgr as rag_mgr
|
||||
from ...platform import botmgr as im_mgr
|
||||
from ...persistence import mgr as persistencemgr
|
||||
from ...api.http.controller import main as http_controller
|
||||
@@ -92,7 +92,7 @@ class BuildAppStage(stage.BootingStage):
|
||||
ap.pipeline_mgr = pipeline_mgr
|
||||
|
||||
rag_mgr_inst = rag_mgr.RAGManager(ap)
|
||||
await rag_mgr_inst.initialize_rag_system()
|
||||
await rag_mgr_inst.initialize()
|
||||
ap.rag_mgr = rag_mgr_inst
|
||||
|
||||
# 初始化向量数据库管理器
|
||||
|
||||
0
pkg/entity/rag/__init__.py
Normal file
0
pkg/entity/rag/__init__.py
Normal file
13
pkg/entity/rag/retriever.py
Normal file
13
pkg/entity/rag/retriever.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pydantic
|
||||
|
||||
from typing import Any
|
||||
|
||||
|
||||
class RetrieveResultEntry(pydantic.BaseModel):
|
||||
id: str
|
||||
|
||||
metadata: dict[str, Any]
|
||||
|
||||
distance: float
|
||||
@@ -101,18 +101,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: RuntimeEmbeddingModel,
|
||||
input_text: str,
|
||||
input_text: list[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> list[float]:
|
||||
) -> list[list[float]]:
|
||||
"""调用 Embedding API
|
||||
|
||||
Args:
|
||||
query (core_entities.Query): 请求上下文
|
||||
model (RuntimeEmbeddingModel): 使用的模型信息
|
||||
input_text (str): 输入文本
|
||||
input_text (list[str]): 输入文本
|
||||
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
|
||||
|
||||
Returns:
|
||||
list[float]: 返回的 embedding 向量
|
||||
list[list[float]]: 返回的 embedding 向量
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -145,9 +145,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
async def invoke_embedding(
|
||||
self,
|
||||
model: requester.RuntimeEmbeddingModel,
|
||||
input_text: str,
|
||||
input_text: list[str],
|
||||
extra_args: dict[str, typing.Any] = {},
|
||||
) -> list[float]:
|
||||
) -> list[list[float]]:
|
||||
"""调用 Embedding API"""
|
||||
self.client.api_key = model.token_mgr.get_token()
|
||||
|
||||
@@ -163,7 +163,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
|
||||
|
||||
try:
|
||||
resp = await self.client.embeddings.create(**args)
|
||||
return resp.data[0].embedding
|
||||
|
||||
return [d.embedding for d in resp.data]
|
||||
except asyncio.TimeoutError:
|
||||
raise errors.RequesterError('请求超时')
|
||||
except openai.BadRequestError as e:
|
||||
|
||||
@@ -1,21 +1,15 @@
|
||||
from __future__ import annotations
|
||||
import os
|
||||
import asyncio
|
||||
import traceback
|
||||
import uuid
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.database import (
|
||||
KnowledgeBase,
|
||||
File,
|
||||
Chunk,
|
||||
)
|
||||
from .services import parser, chunker
|
||||
from pkg.core import app
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
import sqlalchemy
|
||||
from ...entity.persistence import rag as persistence_rag
|
||||
from pkg.core import taskmgr
|
||||
from ...entity.rag import retriever as retriever_entities
|
||||
|
||||
|
||||
class RuntimeKnowledgeBase:
|
||||
@@ -23,9 +17,9 @@ class RuntimeKnowledgeBase:
|
||||
|
||||
knowledge_base_entity: persistence_rag.KnowledgeBase
|
||||
|
||||
parser: FileParser
|
||||
parser: parser.FileParser
|
||||
|
||||
chunker: Chunker
|
||||
chunker: chunker.Chunker
|
||||
|
||||
embedder: Embedder
|
||||
|
||||
@@ -34,8 +28,8 @@ class RuntimeKnowledgeBase:
|
||||
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
|
||||
self.ap = ap
|
||||
self.knowledge_base_entity = knowledge_base_entity
|
||||
self.parser = FileParser(ap=self.ap)
|
||||
self.chunker = Chunker(ap=self.ap)
|
||||
self.parser = parser.FileParser(ap=self.ap)
|
||||
self.chunker = chunker.Chunker(ap=self.ap)
|
||||
self.embedder = Embedder(ap=self.ap)
|
||||
self.retriever = Retriever(ap=self.ap)
|
||||
# 传递kb_id给retriever
|
||||
@@ -66,9 +60,16 @@ class RuntimeKnowledgeBase:
|
||||
raise Exception(f'No chunks extracted from file {file.file_name}')
|
||||
|
||||
task_context.set_current_action('Embedding chunks')
|
||||
|
||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
|
||||
self.knowledge_base_entity.embedding_model_uuid
|
||||
)
|
||||
# embed chunks
|
||||
await self.embedder.embed_and_store(
|
||||
file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid
|
||||
kb_id=self.knowledge_base_entity.uuid,
|
||||
file_id=file.uuid,
|
||||
chunks=chunks_texts,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
# set file status to completed
|
||||
@@ -79,7 +80,8 @@ class RuntimeKnowledgeBase:
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f'Error storing file {file.file_id}: {e}')
|
||||
self.ap.logger.error(f'Error storing file {file.uuid}: {e}')
|
||||
traceback.print_exc()
|
||||
# set file status to failed
|
||||
await self.ap.persistence_mgr.execute_async(
|
||||
sqlalchemy.update(persistence_rag.File)
|
||||
@@ -89,7 +91,7 @@ class RuntimeKnowledgeBase:
|
||||
|
||||
raise
|
||||
|
||||
async def store_file(self, file_id: str) -> int:
|
||||
async def store_file(self, file_id: str) -> str:
|
||||
# pre checking
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
|
||||
raise Exception(f'File {file_id} not found')
|
||||
@@ -97,22 +99,24 @@ class RuntimeKnowledgeBase:
|
||||
file_uuid = str(uuid.uuid4())
|
||||
kb_id = self.knowledge_base_entity.uuid
|
||||
file_name = file_id
|
||||
extension = os.path.splitext(file_id)[1].lstrip('.')
|
||||
extension = file_name.split('.')[-1]
|
||||
|
||||
file = persistence_rag.File(
|
||||
uuid=file_uuid,
|
||||
kb_id=kb_id,
|
||||
file_name=file_name,
|
||||
extension=extension,
|
||||
status='pending',
|
||||
)
|
||||
file_obj_data = {
|
||||
'uuid': file_uuid,
|
||||
'kb_id': kb_id,
|
||||
'file_name': file_name,
|
||||
'extension': extension,
|
||||
'status': 'pending',
|
||||
}
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict()))
|
||||
file_obj = persistence_rag.File(**file_obj_data)
|
||||
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data))
|
||||
|
||||
# run background task asynchronously
|
||||
ctx = taskmgr.TaskContext.new()
|
||||
wrapper = self.ap.task_mgr.create_user_task(
|
||||
self._store_file_task(file, task_context=ctx),
|
||||
self._store_file_task(file_obj, task_context=ctx),
|
||||
kind='knowledge-operation',
|
||||
name=f'knowledge-store-file-{file_id}',
|
||||
label=f'Store file {file_id}',
|
||||
@@ -120,6 +124,12 @@ class RuntimeKnowledgeBase:
|
||||
)
|
||||
return wrapper.id
|
||||
|
||||
async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]:
|
||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
|
||||
self.knowledge_base_entity.embedding_model_uuid
|
||||
)
|
||||
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model)
|
||||
|
||||
async def dispose(self):
|
||||
pass
|
||||
|
||||
@@ -134,7 +144,7 @@ class RAGManager:
|
||||
self.knowledge_bases = []
|
||||
|
||||
async def initialize(self):
|
||||
pass
|
||||
await self.load_knowledge_bases_from_db()
|
||||
|
||||
async def load_knowledge_bases_from_db(self):
|
||||
self.ap.logger.info('Loading knowledge bases from db...')
|
||||
@@ -183,57 +193,6 @@ class RAGManager:
|
||||
self.knowledge_bases.remove(kb)
|
||||
return
|
||||
|
||||
async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None):
|
||||
"""
|
||||
Parses, chunks, embeds, and stores data from a given file into the RAG system.
|
||||
Associates the file with a knowledge base using kb_id in the File table.
|
||||
"""
|
||||
self.ap.logger.info(f'Starting data storage process for file: {file_path}')
|
||||
session = SessionLocal()
|
||||
file_obj = None
|
||||
|
||||
try:
|
||||
kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if not kb:
|
||||
self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ')
|
||||
return
|
||||
# get embedding model
|
||||
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid)
|
||||
file_name = os.path.basename(file_path)
|
||||
text = await self.parser.parse(file_path)
|
||||
if not text:
|
||||
self.ap.logger.warning(f'No text extracted from file {file_path}. ')
|
||||
return
|
||||
|
||||
chunks_texts = await self.chunker.chunk(text)
|
||||
self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
|
||||
await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model)
|
||||
self.ap.logger.info(f'Data storage process completed for file: {file_path}')
|
||||
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
|
||||
raise
|
||||
finally:
|
||||
if file_id:
|
||||
file_obj = session.query(File).filter_by(id=file_id).first()
|
||||
if file_obj:
|
||||
file_obj.status = 1
|
||||
session.close()
|
||||
|
||||
async def retrieve_data(self, query: str):
|
||||
"""
|
||||
Retrieves relevant data chunks based on a given query using the configured retriever.
|
||||
"""
|
||||
self.ap.logger.info(f"Starting data retrieval process for query: '{query}'")
|
||||
try:
|
||||
retrieved_chunks = await self.retriever.retrieve(query)
|
||||
self.ap.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.')
|
||||
return retrieved_chunks
|
||||
except Exception as e:
|
||||
self.ap.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def delete_data_by_file_id(self, file_id: str):
|
||||
"""
|
||||
Deletes all data associated with a specific file ID, including its chunks and vectors,
|
||||
@@ -243,7 +202,7 @@ class RAGManager:
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# delete vectors
|
||||
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
|
||||
await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id)
|
||||
self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}')
|
||||
|
||||
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
|
||||
@@ -333,74 +292,18 @@ class RAGManager:
|
||||
if session.is_active:
|
||||
session.close()
|
||||
|
||||
async def get_file_content_by_file_id(self, file_id: str) -> str:
|
||||
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
|
||||
# async def get_file_content_by_file_id(self, file_id: str) -> str:
|
||||
# file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
|
||||
|
||||
_, ext = os.path.splitext(file_id.lower())
|
||||
ext = ext.lstrip('.')
|
||||
# _, ext = os.path.splitext(file_id.lower())
|
||||
# ext = ext.lstrip('.')
|
||||
|
||||
try:
|
||||
text = file_bytes.decode('utf-8')
|
||||
except UnicodeDecodeError:
|
||||
return '[非文本文件或编码无法识别]'
|
||||
# try:
|
||||
# text = file_bytes.decode('utf-8')
|
||||
# except UnicodeDecodeError:
|
||||
# return '[非文本文件或编码无法识别]'
|
||||
|
||||
if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
|
||||
return text
|
||||
else:
|
||||
return f'[未知类型: .{ext}]'
|
||||
|
||||
async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None:
|
||||
"""
|
||||
Associates a file with a knowledge base by updating the kb_id in the File table.
|
||||
"""
|
||||
self.ap.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}')
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 查询知识库是否存在
|
||||
kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first()
|
||||
if not kb:
|
||||
self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
|
||||
return
|
||||
|
||||
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
|
||||
self.ap.logger.error(f'File with ID {file_id} does not exist.')
|
||||
return
|
||||
self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.')
|
||||
# add new file record
|
||||
file_to_update = File(
|
||||
id=file_id,
|
||||
kb_id=kb.id,
|
||||
file_name=file_id,
|
||||
path=os.path.join('data', 'storage', file_id),
|
||||
file_type=os.path.splitext(file_id)[1].lstrip('.'),
|
||||
status=0,
|
||||
)
|
||||
session.add(file_to_update)
|
||||
session.commit()
|
||||
self.ap.logger.info(
|
||||
f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
|
||||
)
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.ap.logger.error(
|
||||
f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}',
|
||||
exc_info=True,
|
||||
)
|
||||
finally:
|
||||
# 进行文件解析
|
||||
try:
|
||||
await self.store_data(
|
||||
file_path=os.path.join('data', 'storage', file_id),
|
||||
kb_id=knowledge_base_uuid,
|
||||
file_type=os.path.splitext(file_id)[1].lstrip('.'),
|
||||
file_id=file_id,
|
||||
)
|
||||
except Exception:
|
||||
# 如果存储数据时出错,更新文件状态为失败
|
||||
file_obj = session.query(File).filter_by(id=file_id).first()
|
||||
if file_obj:
|
||||
file_obj.status = 2
|
||||
session.commit()
|
||||
self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True)
|
||||
|
||||
session.close()
|
||||
# if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
|
||||
# return text
|
||||
# else:
|
||||
# return f'[未知类型: .{ext}]'
|
||||
@@ -1,12 +1,10 @@
|
||||
# 封装异步操作
|
||||
import asyncio
|
||||
import logging
|
||||
from pkg.rag.knowledge.services.database import SessionLocal
|
||||
|
||||
|
||||
class BaseService:
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.db_session_factory = SessionLocal
|
||||
pass
|
||||
|
||||
async def _run_sync(self, func, *args, **kwargs):
|
||||
"""
|
||||
@@ -14,13 +12,4 @@ class BaseService:
|
||||
如果第一个参数是 session,则在 to_thread 中获取新的 session。
|
||||
"""
|
||||
|
||||
if getattr(func, '__name__', '').startswith('_db_'):
|
||||
session = await asyncio.to_thread(self.db_session_factory)
|
||||
try:
|
||||
result = await asyncio.to_thread(func, session, *args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
session.close()
|
||||
else:
|
||||
# 否则,直接运行同步函数
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
|
||||
@@ -1,24 +1,21 @@
|
||||
# services/chunker.py
|
||||
import logging
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List
|
||||
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
|
||||
from pkg.rag.knowledge.services import base_service
|
||||
from pkg.core import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Chunker(BaseService):
|
||||
class Chunker(base_service.BaseService):
|
||||
"""
|
||||
A class for splitting long texts into smaller, overlapping chunks.
|
||||
"""
|
||||
|
||||
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
|
||||
super().__init__(ap) # Initialize BaseService
|
||||
self.ap = ap
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
if self.chunk_overlap >= self.chunk_size:
|
||||
self.logger.warning(
|
||||
self.ap.logger.warning(
|
||||
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
|
||||
)
|
||||
|
||||
|
||||
@@ -1,23 +0,0 @@
|
||||
# 全部迁移过去
|
||||
|
||||
from pkg.entity.persistence.rag import (
|
||||
create_db_and_tables,
|
||||
SessionLocal,
|
||||
Base,
|
||||
engine,
|
||||
KnowledgeBase,
|
||||
File,
|
||||
Chunk,
|
||||
Vector,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"create_db_and_tables",
|
||||
"SessionLocal",
|
||||
"Base",
|
||||
"engine",
|
||||
"KnowledgeBase",
|
||||
"File",
|
||||
"Chunk",
|
||||
"Vector",
|
||||
]
|
||||
|
||||
@@ -1,12 +1,11 @@
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import numpy as np
|
||||
import uuid
|
||||
from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from ....entity.persistence import rag as persistence_rag
|
||||
from ....core import app
|
||||
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
|
||||
import sqlalchemy
|
||||
|
||||
|
||||
class Embedder(BaseService):
|
||||
@@ -14,74 +13,41 @@ class Embedder(BaseService):
|
||||
super().__init__()
|
||||
self.ap = ap
|
||||
|
||||
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
|
||||
"""
|
||||
Saves chunks to the relational database and returns the created Chunk objects.
|
||||
This function assumes it's called within a context where the session
|
||||
will be committed/rolled back and closed by the caller.
|
||||
"""
|
||||
self.ap.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).')
|
||||
chunk_objects = []
|
||||
for text in chunks_texts:
|
||||
chunk = Chunk(file_id=file_id, text=text)
|
||||
session.add(chunk)
|
||||
chunk_objects.append(chunk)
|
||||
session.flush() # This populates the .id attribute for each new chunk object
|
||||
self.ap.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.')
|
||||
return chunk_objects
|
||||
|
||||
async def embed_and_store(
|
||||
self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel
|
||||
) -> List[Chunk]:
|
||||
session = SessionLocal() # Start a session that will live for the whole operation
|
||||
chunk_objects = []
|
||||
try:
|
||||
# 1. Save chunks to the relational database first to get their IDs
|
||||
# We call _db_save_chunks_sync directly without _run_sync's session management
|
||||
# because we manage the session here across multiple async calls.
|
||||
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
|
||||
session.commit() # Commit chunks to make their IDs permanent and accessible
|
||||
self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
|
||||
) -> list[persistence_rag.Chunk]:
|
||||
# save chunk to db
|
||||
chunk_entities: list[persistence_rag.Chunk] = []
|
||||
chunk_ids: list[str] = []
|
||||
|
||||
if not chunk_objects:
|
||||
self.ap.logger.warning(
|
||||
f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.'
|
||||
)
|
||||
return []
|
||||
for chunk_text in chunks:
|
||||
chunk_uuid = str(uuid.uuid4())
|
||||
chunk_ids.append(chunk_uuid)
|
||||
chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
|
||||
chunk_entities.append(chunk_entity)
|
||||
|
||||
# get the embeddings for the chunks
|
||||
embeddings: list[list[float]] = []
|
||||
chunk_dicts = [
|
||||
self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
|
||||
]
|
||||
|
||||
for chunk in chunks:
|
||||
result = await embedding_model.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=chunk,
|
||||
)
|
||||
embeddings.append(result)
|
||||
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
|
||||
|
||||
embeddings_np = np.array(embeddings, dtype=np.float32)
|
||||
# get embeddings
|
||||
embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=chunks,
|
||||
extra_args={}, # TODO: add extra args
|
||||
)
|
||||
|
||||
chunk_ids = [c.id for c in chunk_objects]
|
||||
# collection名用kb_id(file对象有kb_id字段)
|
||||
kb_id = session.query(Chunk).filter_by(id=chunk_ids[0]).first().file.kb_id if chunk_ids else None
|
||||
if not kb_id:
|
||||
self.ap.logger.warning('无法获取kb_id,向量存储失败')
|
||||
return chunk_objects
|
||||
chroma_ids = [f'{file_id}_{cid}' for cid in chunk_ids]
|
||||
metadatas = [{'file_id': file_id, 'chunk_id': cid} for cid in chunk_ids]
|
||||
await self._run_sync(
|
||||
self.ap.vector_db_mgr.vector_db.add_embeddings,
|
||||
kb_id,
|
||||
chroma_ids,
|
||||
embeddings_np,
|
||||
metadatas,
|
||||
chunks,
|
||||
)
|
||||
self.ap.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to VectorDB.')
|
||||
return chunk_objects
|
||||
# save embeddings to vdb
|
||||
await self._run_sync(
|
||||
self.ap.vector_db_mgr.vector_db.add_embeddings,
|
||||
kb_id,
|
||||
chunk_ids,
|
||||
embeddings_list,
|
||||
chunk_dicts,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
session.rollback() # Rollback on any error
|
||||
self.ap.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True)
|
||||
raise # Re-raise the exception to propagate it
|
||||
finally:
|
||||
session.close() # Ensure the session is always closed
|
||||
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
|
||||
|
||||
return chunk_entities
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import PyPDF2
|
||||
import io
|
||||
from docx import Document
|
||||
|
||||
@@ -1,99 +1,46 @@
|
||||
from __future__ import annotations
|
||||
import logging
|
||||
import numpy as np # Make sure numpy is imported
|
||||
from typing import List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.vector.vdb import VectorDatabase
|
||||
|
||||
from . import base_service
|
||||
from ....core import app
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
|
||||
from ....entity.rag import retriever as retriever_entities
|
||||
|
||||
|
||||
class Retriever(BaseService):
|
||||
class Retriever(base_service.BaseService):
|
||||
def __init__(self, ap: app.Application):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.ap = ap
|
||||
self.vector_db: VectorDatabase = ap.vector_db_mgr.vector_db
|
||||
|
||||
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
||||
if not self.embedding_model:
|
||||
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
|
||||
async def retrieve(
|
||||
self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
|
||||
) -> list[retriever_entities.RetrieveResultEntry]:
|
||||
self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}")
|
||||
|
||||
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
|
||||
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
|
||||
model=embedding_model,
|
||||
input_text=[query],
|
||||
extra_args={}, # TODO: add extra args
|
||||
)
|
||||
|
||||
query_embedding: List[float] = await self.embedding_model.embed_query(query)
|
||||
query_embedding_np = np.array([query_embedding], dtype=np.float32)
|
||||
|
||||
# collection名用kb_id(假设retriever有kb_id属性或通过ap传递)
|
||||
kb_id = getattr(self, 'kb_id', None)
|
||||
if not kb_id:
|
||||
self.logger.warning('无法获取kb_id,向量检索失败')
|
||||
return []
|
||||
chroma_results = await self._run_sync(self.vector_db.search, kb_id, query_embedding_np, k)
|
||||
chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k)
|
||||
|
||||
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
|
||||
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
|
||||
distances = chroma_results.get('distances', [[]])[0]
|
||||
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
|
||||
chroma_documents = chroma_results.get('documents', [[]])[0]
|
||||
|
||||
if not matched_chroma_ids:
|
||||
self.logger.info('No relevant chunks found in Chroma.')
|
||||
self.ap.logger.info('No relevant chunks found in Chroma.')
|
||||
return []
|
||||
|
||||
db_chunk_ids = []
|
||||
for metadata in chroma_metadatas:
|
||||
if 'chunk_id' in metadata:
|
||||
db_chunk_ids.append(metadata['chunk_id'])
|
||||
else:
|
||||
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
|
||||
result: list[retriever_entities.RetrieveResultEntry] = []
|
||||
|
||||
if not db_chunk_ids:
|
||||
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
|
||||
return []
|
||||
|
||||
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
|
||||
chunks_from_db = await self._run_sync(
|
||||
lambda cids: self._db_get_chunks_sync(
|
||||
SessionLocal(), cids
|
||||
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
|
||||
db_chunk_ids,
|
||||
)
|
||||
|
||||
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
|
||||
results_list: List[Dict[str, Any]] = []
|
||||
|
||||
for i, chroma_id in enumerate(matched_chroma_ids):
|
||||
try:
|
||||
# Ensure original_chunk_id is int for DB lookup
|
||||
original_chunk_id = int(chroma_id.split('_')[-1])
|
||||
except (ValueError, IndexError):
|
||||
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
|
||||
continue
|
||||
|
||||
chunk_text_from_chroma = chroma_documents[i]
|
||||
distance = float(distances[i])
|
||||
file_id_from_chroma = chroma_metadatas[i].get('file_id')
|
||||
|
||||
chunk_from_db = chunk_map.get(original_chunk_id)
|
||||
|
||||
results_list.append(
|
||||
{
|
||||
'chunk_id': original_chunk_id,
|
||||
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
|
||||
'distance': distance,
|
||||
'file_id': file_id_from_chroma,
|
||||
}
|
||||
for i, id in enumerate(matched_chroma_ids):
|
||||
entry = retriever_entities.RetrieveResultEntry(
|
||||
id=id,
|
||||
metadata=chroma_metadatas[i],
|
||||
distance=distances[i],
|
||||
)
|
||||
result.append(entry)
|
||||
|
||||
self.logger.info(f'Retrieved {len(results_list)} chunks.')
|
||||
return results_list
|
||||
|
||||
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
|
||||
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
|
||||
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
|
||||
session.close()
|
||||
return chunks
|
||||
return result
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import abc
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, Dict
|
||||
import numpy as np
|
||||
|
||||
|
||||
@@ -9,10 +9,10 @@ class VectorDatabase(abc.ABC):
|
||||
def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: List[str],
|
||||
embeddings: np.ndarray,
|
||||
metadatas: List[Dict[str, Any]],
|
||||
documents: List[str],
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
documents: list[str],
|
||||
) -> None:
|
||||
"""向指定 collection 添加向量数据。"""
|
||||
pass
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
import numpy as np
|
||||
from typing import Any, List, Dict
|
||||
import chromadb
|
||||
from typing import Any
|
||||
from chromadb import PersistentClient
|
||||
from pkg.vector.vdb import VectorDatabase
|
||||
from pkg.core import app
|
||||
@@ -12,7 +12,7 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
self.client = PersistentClient(path=base_path)
|
||||
self._collections = {}
|
||||
|
||||
def get_or_create_collection(self, collection: str):
|
||||
def get_or_create_collection(self, collection: str) -> chromadb.Collection:
|
||||
if collection not in self._collections:
|
||||
self._collections[collection] = self.client.get_or_create_collection(name=collection)
|
||||
self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
|
||||
@@ -21,26 +21,25 @@ class ChromaVectorDatabase(VectorDatabase):
|
||||
def add_embeddings(
|
||||
self,
|
||||
collection: str,
|
||||
ids: List[str],
|
||||
embeddings: np.ndarray,
|
||||
metadatas: List[Dict[str, Any]],
|
||||
documents: List[str],
|
||||
ids: list[str],
|
||||
embeddings_list: list[list[float]],
|
||||
metadatas: list[dict[str, Any]],
|
||||
) -> None:
|
||||
col = self.get_or_create_collection(collection)
|
||||
col.add(embeddings=embeddings.tolist(), ids=ids, metadatas=metadatas, documents=documents)
|
||||
col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
|
||||
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
|
||||
|
||||
def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
|
||||
def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
|
||||
col = self.get_or_create_collection(collection)
|
||||
results = col.query(
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
query_embeddings=query_embedding,
|
||||
n_results=k,
|
||||
include=['metadatas', 'distances', 'documents'],
|
||||
)
|
||||
self.ap.logger.debug(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
|
||||
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
|
||||
return results
|
||||
|
||||
def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None:
|
||||
def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None:
|
||||
col = self.get_or_create_collection(collection)
|
||||
col.delete(where=where)
|
||||
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")
|
||||
|
||||
Reference in New Issue
Block a user