feat(rag): make embedding and retrieving available

This commit is contained in:
Junyan Qin
2025-07-16 21:17:18 +08:00
parent f731115805
commit 2f2db4d445
20 changed files with 180 additions and 368 deletions

View File

@@ -86,10 +86,10 @@ class RouterGroup(abc.ABC):
try: try:
return await f(*args, **kwargs) return await f(*args, **kwargs)
except Exception: # 自动 500 except Exception as e: # 自动 500
traceback.print_exc() traceback.print_exc()
# return self.http_status(500, -2, str(e)) # return self.http_status(500, -2, str(e))
return self.http_status(500, -2, 'internal server error') return self.http_status(500, -2, str(e))
new_f = handler_error new_f = handler_error
new_f.__name__ = (self.name + rule).replace('/', '__') new_f.__name__ = (self.name + rule).replace('/', '__')

View File

@@ -34,8 +34,9 @@ class FilesRouterGroup(group.RouterGroup):
file_bytes = await asyncio.to_thread(file.stream.read) file_bytes = await asyncio.to_thread(file.stream.read)
extension = file.filename.split('.')[-1] extension = file.filename.split('.')[-1]
file_name = file.filename.split('.')[0]
file_key = str(uuid.uuid4()) + '.' + extension file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
# save file to storage # save file to storage
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
return self.success( return self.success(

View File

@@ -72,3 +72,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str: async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str:
await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id) await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
return self.success({}) return self.success({})
@self.route(
    '/<knowledge_base_uuid>/retrieve',
    methods=['POST'],
)
async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
    """Run a retrieval query against one knowledge base.

    Expects a JSON object body carrying a non-empty string under ``query`` and
    responds with the service results under the ``results`` key.

    Args:
        knowledge_base_uuid: Taken from the URL; identifies the knowledge base to search.

    Returns:
        The success response, or a 400 response when the body is unusable.
    """
    json_data = await quart.request.json
    # `request.json` is None when the request carries no JSON body; reject it
    # explicitly instead of failing on `.get` with an AttributeError.
    if not isinstance(json_data, dict):
        return self.http_status(400, -1, 'request body must be a JSON object')

    query = json_data.get('query')
    # Never forward a missing/blank query to the embedding request.
    if not isinstance(query, str) or not query.strip():
        return self.http_status(400, -1, 'query is required')

    results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query)
    return self.success(data={'results': results})

View File

@@ -67,6 +67,13 @@ class KnowledgeService:
raise Exception('Knowledge base not found') raise Exception('Knowledge base not found')
return await runtime_kb.store_file(file_id) return await runtime_kb.store_file(file_id)
async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
    """Retrieve the entries of knowledge base ``kb_uuid`` that match ``query``.

    Each entry returned by the runtime knowledge base is converted to a plain
    dict through its ``model_dump()`` method.

    Raises:
        Exception: If no runtime knowledge base is found for ``kb_uuid``.
    """
    knowledge_base = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
    if knowledge_base is None:
        raise Exception('Knowledge base not found')

    entries = await knowledge_base.retrieve(query)
    serialized: list[dict] = []
    for entry in entries:
        serialized.append(entry.model_dump())
    return serialized
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]: async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
"""获取知识库文件""" """获取知识库文件"""
result = await self.ap.persistence_mgr.execute_async( result = await self.ap.persistence_mgr.execute_async(

View File

@@ -186,6 +186,6 @@ class EmbeddingModelsService:
await runtime_embedding_model.requester.invoke_embedding( await runtime_embedding_model.requester.invoke_embedding(
model=runtime_embedding_model, model=runtime_embedding_model,
input_text='Hello, world!', input_text=['Hello, world!'],
extra_args={}, extra_args={},
) )

View File

@@ -28,7 +28,7 @@ from ..storage import mgr as storagemgr
from ..utils import logcache from ..utils import logcache
from . import taskmgr from . import taskmgr
from . import entities as core_entities from . import entities as core_entities
from ..rag.knowledge import mgr as rag_mgr from ..rag.knowledge import kbmgr as rag_mgr
from ..vector import mgr as vectordb_mgr from ..vector import mgr as vectordb_mgr

View File

@@ -9,7 +9,7 @@ from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.tools import toolmgr as llm_tool_mgr from ...provider.tools import toolmgr as llm_tool_mgr
from ...rag.knowledge import mgr as rag_mgr from ...rag.knowledge import kbmgr as rag_mgr
from ...platform import botmgr as im_mgr from ...platform import botmgr as im_mgr
from ...persistence import mgr as persistencemgr from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller from ...api.http.controller import main as http_controller
@@ -92,7 +92,7 @@ class BuildAppStage(stage.BootingStage):
ap.pipeline_mgr = pipeline_mgr ap.pipeline_mgr = pipeline_mgr
rag_mgr_inst = rag_mgr.RAGManager(ap) rag_mgr_inst = rag_mgr.RAGManager(ap)
await rag_mgr_inst.initialize_rag_system() await rag_mgr_inst.initialize()
ap.rag_mgr = rag_mgr_inst ap.rag_mgr = rag_mgr_inst
# 初始化向量数据库管理器 # 初始化向量数据库管理器

View File

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
import pydantic
from typing import Any
class RetrieveResultEntry(pydantic.BaseModel):
    """One match returned by a knowledge base retrieval."""

    # Identifier of the matched entry as reported by the vector database
    id: str
    # Metadata stored alongside the matched entry
    metadata: dict[str, Any]
    # Distance reported for this match by the vector search
    distance: float

View File

@@ -101,18 +101,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
async def invoke_embedding( async def invoke_embedding(
self, self,
model: RuntimeEmbeddingModel, model: RuntimeEmbeddingModel,
input_text: str, input_text: list[str],
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> list[float]: ) -> list[list[float]]:
"""调用 Embedding API """调用 Embedding API
Args: Args:
query (core_entities.Query): 请求上下文 query (core_entities.Query): 请求上下文
model (RuntimeEmbeddingModel): 使用的模型信息 model (RuntimeEmbeddingModel): 使用的模型信息
input_text (str): 输入文本 input_text (list[str]): 输入文本
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns: Returns:
list[float]: 返回的 embedding 向量 list[list[float]]: 返回的 embedding 向量
""" """
pass pass

View File

@@ -145,9 +145,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
async def invoke_embedding( async def invoke_embedding(
self, self,
model: requester.RuntimeEmbeddingModel, model: requester.RuntimeEmbeddingModel,
input_text: str, input_text: list[str],
extra_args: dict[str, typing.Any] = {}, extra_args: dict[str, typing.Any] = {},
) -> list[float]: ) -> list[list[float]]:
"""调用 Embedding API""" """调用 Embedding API"""
self.client.api_key = model.token_mgr.get_token() self.client.api_key = model.token_mgr.get_token()
@@ -163,7 +163,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
try: try:
resp = await self.client.embeddings.create(**args) resp = await self.client.embeddings.create(**args)
return resp.data[0].embedding
return [d.embedding for d in resp.data]
except asyncio.TimeoutError: except asyncio.TimeoutError:
raise errors.RequesterError('请求超时') raise errors.RequesterError('请求超时')
except openai.BadRequestError as e: except openai.BadRequestError as e:

View File

@@ -1,21 +1,15 @@
from __future__ import annotations from __future__ import annotations
import os
import asyncio import asyncio
import traceback import traceback
import uuid import uuid
from pkg.rag.knowledge.services.parser import FileParser from .services import parser, chunker
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.database import (
KnowledgeBase,
File,
Chunk,
)
from pkg.core import app from pkg.core import app
from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever from pkg.rag.knowledge.services.retriever import Retriever
import sqlalchemy import sqlalchemy
from ...entity.persistence import rag as persistence_rag from ...entity.persistence import rag as persistence_rag
from pkg.core import taskmgr from pkg.core import taskmgr
from ...entity.rag import retriever as retriever_entities
class RuntimeKnowledgeBase: class RuntimeKnowledgeBase:
@@ -23,9 +17,9 @@ class RuntimeKnowledgeBase:
knowledge_base_entity: persistence_rag.KnowledgeBase knowledge_base_entity: persistence_rag.KnowledgeBase
parser: FileParser parser: parser.FileParser
chunker: Chunker chunker: chunker.Chunker
embedder: Embedder embedder: Embedder
@@ -34,8 +28,8 @@ class RuntimeKnowledgeBase:
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
self.ap = ap self.ap = ap
self.knowledge_base_entity = knowledge_base_entity self.knowledge_base_entity = knowledge_base_entity
self.parser = FileParser(ap=self.ap) self.parser = parser.FileParser(ap=self.ap)
self.chunker = Chunker(ap=self.ap) self.chunker = chunker.Chunker(ap=self.ap)
self.embedder = Embedder(ap=self.ap) self.embedder = Embedder(ap=self.ap)
self.retriever = Retriever(ap=self.ap) self.retriever = Retriever(ap=self.ap)
# 传递kb_id给retriever # 传递kb_id给retriever
@@ -66,9 +60,16 @@ class RuntimeKnowledgeBase:
raise Exception(f'No chunks extracted from file {file.file_name}') raise Exception(f'No chunks extracted from file {file.file_name}')
task_context.set_current_action('Embedding chunks') task_context.set_current_action('Embedding chunks')
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
self.knowledge_base_entity.embedding_model_uuid
)
# embed chunks # embed chunks
await self.embedder.embed_and_store( await self.embedder.embed_and_store(
file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid kb_id=self.knowledge_base_entity.uuid,
file_id=file.uuid,
chunks=chunks_texts,
embedding_model=embedding_model,
) )
# set file status to completed # set file status to completed
@@ -79,7 +80,8 @@ class RuntimeKnowledgeBase:
) )
except Exception as e: except Exception as e:
self.ap.logger.error(f'Error storing file {file.file_id}: {e}') self.ap.logger.error(f'Error storing file {file.uuid}: {e}')
traceback.print_exc()
# set file status to failed # set file status to failed
await self.ap.persistence_mgr.execute_async( await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File) sqlalchemy.update(persistence_rag.File)
@@ -89,7 +91,7 @@ class RuntimeKnowledgeBase:
raise raise
async def store_file(self, file_id: str) -> int: async def store_file(self, file_id: str) -> str:
# pre checking # pre checking
if not await self.ap.storage_mgr.storage_provider.exists(file_id): if not await self.ap.storage_mgr.storage_provider.exists(file_id):
raise Exception(f'File {file_id} not found') raise Exception(f'File {file_id} not found')
@@ -97,22 +99,24 @@ class RuntimeKnowledgeBase:
file_uuid = str(uuid.uuid4()) file_uuid = str(uuid.uuid4())
kb_id = self.knowledge_base_entity.uuid kb_id = self.knowledge_base_entity.uuid
file_name = file_id file_name = file_id
extension = os.path.splitext(file_id)[1].lstrip('.') extension = file_name.split('.')[-1]
file = persistence_rag.File( file_obj_data = {
uuid=file_uuid, 'uuid': file_uuid,
kb_id=kb_id, 'kb_id': kb_id,
file_name=file_name, 'file_name': file_name,
extension=extension, 'extension': extension,
status='pending', 'status': 'pending',
) }
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict())) file_obj = persistence_rag.File(**file_obj_data)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data))
# run background task asynchronously # run background task asynchronously
ctx = taskmgr.TaskContext.new() ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task( wrapper = self.ap.task_mgr.create_user_task(
self._store_file_task(file, task_context=ctx), self._store_file_task(file_obj, task_context=ctx),
kind='knowledge-operation', kind='knowledge-operation',
name=f'knowledge-store-file-{file_id}', name=f'knowledge-store-file-{file_id}',
label=f'Store file {file_id}', label=f'Store file {file_id}',
@@ -120,6 +124,12 @@ class RuntimeKnowledgeBase:
) )
return wrapper.id return wrapper.id
async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]:
    """Search this knowledge base for entries relevant to ``query``.

    Resolves the embedding model configured on the knowledge base entity, then
    passes the knowledge base uuid, the query and that model to the retriever.
    """
    model_uuid = self.knowledge_base_entity.embedding_model_uuid
    model = await self.ap.model_mgr.get_embedding_model_by_uuid(model_uuid)

    kb_uuid = self.knowledge_base_entity.uuid
    hits = await self.retriever.retrieve(kb_uuid, query, model)
    return hits
async def dispose(self): async def dispose(self):
pass pass
@@ -134,7 +144,7 @@ class RAGManager:
self.knowledge_bases = [] self.knowledge_bases = []
async def initialize(self): async def initialize(self):
pass await self.load_knowledge_bases_from_db()
async def load_knowledge_bases_from_db(self): async def load_knowledge_bases_from_db(self):
self.ap.logger.info('Loading knowledge bases from db...') self.ap.logger.info('Loading knowledge bases from db...')
@@ -183,57 +193,6 @@ class RAGManager:
self.knowledge_bases.remove(kb) self.knowledge_bases.remove(kb)
return return
async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None):
    """Parse, chunk, embed and store the contents of one file.

    Args:
        file_path: Path handed to the parser.
        kb_id: Id of the knowledge base the file belongs to; nothing is stored
            when no such knowledge base exists.
        file_type: Not read by this method; kept so existing callers keep working.
        file_id: Id of the ``File`` row to update once processing succeeded, and
            the id passed on to the embedder.

    Raises:
        Exception: Any failure is rolled back, logged and re-raised.
    """
    self.ap.logger.info(f'Starting data storage process for file: {file_path}')
    session = SessionLocal()

    try:
        kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
        if not kb:
            self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ')
            return

        # get embedding model
        embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid)

        file_name = os.path.basename(file_path)
        text = await self.parser.parse(file_path)
        if not text:
            self.ap.logger.warning(f'No text extracted from file {file_path}. ')
            return

        chunks_texts = await self.chunker.chunk(text)
        self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
        await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model)

        # Update the status only after every step succeeded, and commit it:
        # the previous version wrote it in `finally` (so also on failures and
        # early returns) and never committed, so the change was always lost.
        # NOTE(review): status 1 presumably means "processed" (0 is used on
        # creation, 2 on failure by the caller) - confirm against the model.
        if file_id:
            file_obj = session.query(File).filter_by(id=file_id).first()
            if file_obj:
                file_obj.status = 1
                session.commit()

        self.ap.logger.info(f'Data storage process completed for file: {file_path}')
    except Exception as e:
        session.rollback()
        self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
        raise
    finally:
        session.close()
async def retrieve_data(self, query: str):
    """Fetch the chunks the configured retriever returns for ``query``.

    A failure inside the retriever is logged and turned into an empty list
    instead of being propagated to the caller.
    """
    log = self.ap.logger
    log.info(f"Starting data retrieval process for query: '{query}'")

    try:
        chunks = await self.retriever.retrieve(query)
        log.info(f'Successfully retrieved {len(chunks)} chunks for query.')
    except Exception as e:
        log.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
        return []
    return chunks
async def delete_data_by_file_id(self, file_id: str): async def delete_data_by_file_id(self, file_id: str):
""" """
Deletes all data associated with a specific file ID, including its chunks and vectors, Deletes all data associated with a specific file ID, including its chunks and vectors,
@@ -243,7 +202,7 @@ class RAGManager:
session = SessionLocal() session = SessionLocal()
try: try:
# delete vectors # delete vectors
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id)
self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}')
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
@@ -333,74 +292,18 @@ class RAGManager:
if session.is_active: if session.is_active:
session.close() session.close()
async def get_file_content_by_file_id(self, file_id: str) -> str: # async def get_file_content_by_file_id(self, file_id: str) -> str:
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) # file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
_, ext = os.path.splitext(file_id.lower()) # _, ext = os.path.splitext(file_id.lower())
ext = ext.lstrip('.') # ext = ext.lstrip('.')
try: # try:
text = file_bytes.decode('utf-8') # text = file_bytes.decode('utf-8')
except UnicodeDecodeError: # except UnicodeDecodeError:
return '[非文本文件或编码无法识别]' # return '[非文本文件或编码无法识别]'
if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: # if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
return text # return text
else: # else:
return f'[未知类型: .{ext}]' # return f'[未知类型: .{ext}]'
async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None:
    """Associate a stored file with a knowledge base, then parse and store it.

    A ``File`` row pointing at the knowledge base is inserted first. Only when
    that insert succeeded is the file handed to ``store_data``; if that step
    fails the row's status is set to 2 and the error is logged.

    Args:
        knowledge_base_uuid: Id of the knowledge base the file is attached to.
        file_id: Key of the file in the storage provider; also used as the row
            id and file name.
    """
    self.ap.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}')
    session = SessionLocal()
    file_path = os.path.join('data', 'storage', file_id)
    file_type = os.path.splitext(file_id)[1].lstrip('.')

    try:
        try:
            # Check that the knowledge base exists
            kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first()
            if not kb:
                self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
                return

            if not await self.ap.storage_mgr.storage_provider.exists(file_id):
                self.ap.logger.error(f'File with ID {file_id} does not exist.')
                return

            self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.')
            # add new file record
            file_to_update = File(
                id=file_id,
                kb_id=kb.id,
                file_name=file_id,
                path=file_path,
                file_type=file_type,
                status=0,
            )
            session.add(file_to_update)
            session.commit()
            self.ap.logger.info(
                f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
            )
        except Exception as e:
            session.rollback()
            self.ap.logger.error(
                f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}',
                exc_info=True,
            )
            return

        # Parse the file only after a successful association. This step used to
        # live in `finally`, so it also ran after the early returns above and
        # after a failed insert.
        try:
            await self.store_data(
                file_path=file_path,
                kb_id=knowledge_base_uuid,
                file_type=file_type,
                file_id=file_id,
            )
        except Exception:
            # Storing failed: mark the file record as failed
            file_obj = session.query(File).filter_by(id=file_id).first()
            if file_obj:
                file_obj.status = 2
                session.commit()
            self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True)
    finally:
        session.close()

View File

@@ -1,26 +1,15 @@
# 封装异步操作 # 封装异步操作
import asyncio import asyncio
import logging
from pkg.rag.knowledge.services.database import SessionLocal
class BaseService: class BaseService:
def __init__(self): def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__) pass
self.db_session_factory = SessionLocal
async def _run_sync(self, func, *args, **kwargs): async def _run_sync(self, func, *args, **kwargs):
""" """
在单独的线程中运行同步函数。 在单独的线程中运行同步函数。
如果第一个参数是 session则在 to_thread 中获取新的 session。 如果第一个参数是 session则在 to_thread 中获取新的 session。
""" """
if getattr(func, '__name__', '').startswith('_db_'): return await asyncio.to_thread(func, *args, **kwargs)
session = await asyncio.to_thread(self.db_session_factory)
try:
result = await asyncio.to_thread(func, session, *args, **kwargs)
return result
finally:
session.close()
else:
# 否则,直接运行同步函数
return await asyncio.to_thread(func, *args, **kwargs)

View File

@@ -1,24 +1,21 @@
# services/chunker.py from __future__ import annotations
import logging
from typing import List from typing import List
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync from pkg.rag.knowledge.services import base_service
from pkg.core import app from pkg.core import app
logger = logging.getLogger(__name__)
class Chunker(base_service.BaseService):
class Chunker(BaseService):
""" """
A class for splitting long texts into smaller, overlapping chunks. A class for splitting long texts into smaller, overlapping chunks.
""" """
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50): def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__(ap) # Initialize BaseService
self.ap = ap self.ap = ap
self.chunk_size = chunk_size self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap self.chunk_overlap = chunk_overlap
if self.chunk_overlap >= self.chunk_size: if self.chunk_overlap >= self.chunk_size:
self.logger.warning( self.ap.logger.warning(
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.' 'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
) )

View File

@@ -1,23 +0,0 @@
# 全部迁移过去
from pkg.entity.persistence.rag import (
create_db_and_tables,
SessionLocal,
Base,
engine,
KnowledgeBase,
File,
Chunk,
Vector,
)
__all__ = [
"create_db_and_tables",
"SessionLocal",
"Base",
"engine",
"KnowledgeBase",
"File",
"Chunk",
"Vector",
]

View File

@@ -1,12 +1,11 @@
from __future__ import annotations from __future__ import annotations
import asyncio import uuid
import numpy as np
from typing import List from typing import List
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal from ....entity.persistence import rag as persistence_rag
from ....core import app from ....core import app
from ....provider.modelmgr.requester import RuntimeEmbeddingModel from ....provider.modelmgr.requester import RuntimeEmbeddingModel
import sqlalchemy
class Embedder(BaseService): class Embedder(BaseService):
@@ -14,74 +13,41 @@ class Embedder(BaseService):
super().__init__() super().__init__()
self.ap = ap self.ap = ap
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
    """Add one ``Chunk`` row per text to ``session`` and return the new objects.

    The session is flushed but neither committed nor closed here; the caller
    owns the transaction and the session lifetime.
    """
    log = self.ap.logger
    log.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).')

    created = []
    for chunk_text in chunks_texts:
        row = Chunk(file_id=file_id, text=chunk_text)
        session.add(row)
        created.append(row)

    # Flushing assigns the .id attribute of every newly added chunk
    session.flush()
    log.debug(f'Successfully added {len(created)} chunk entries to DB.')
    return created
async def embed_and_store( async def embed_and_store(
self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> List[Chunk]: ) -> list[persistence_rag.Chunk]:
session = SessionLocal() # Start a session that will live for the whole operation # save chunk to db
chunk_objects = [] chunk_entities: list[persistence_rag.Chunk] = []
try: chunk_ids: list[str] = []
# 1. Save chunks to the relational database first to get their IDs
# We call _db_save_chunks_sync directly without _run_sync's session management
# because we manage the session here across multiple async calls.
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
session.commit() # Commit chunks to make their IDs permanent and accessible
if not chunk_objects: for chunk_text in chunks:
self.ap.logger.warning( chunk_uuid = str(uuid.uuid4())
f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.' chunk_ids.append(chunk_uuid)
) chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
return [] chunk_entities.append(chunk_entity)
# get the embeddings for the chunks chunk_dicts = [
embeddings: list[list[float]] = [] self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
]
for chunk in chunks: await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
result = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunk,
)
embeddings.append(result)
embeddings_np = np.array(embeddings, dtype=np.float32) # get embeddings
embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunks,
extra_args={}, # TODO: add extra args
)
chunk_ids = [c.id for c in chunk_objects] # save embeddings to vdb
# collection名用kb_idfile对象有kb_id字段 await self._run_sync(
kb_id = session.query(Chunk).filter_by(id=chunk_ids[0]).first().file.kb_id if chunk_ids else None self.ap.vector_db_mgr.vector_db.add_embeddings,
if not kb_id: kb_id,
self.ap.logger.warning('无法获取kb_id向量存储失败') chunk_ids,
return chunk_objects embeddings_list,
chroma_ids = [f'{file_id}_{cid}' for cid in chunk_ids] chunk_dicts,
metadatas = [{'file_id': file_id, 'chunk_id': cid} for cid in chunk_ids] )
await self._run_sync(
self.ap.vector_db_mgr.vector_db.add_embeddings,
kb_id,
chroma_ids,
embeddings_np,
metadatas,
chunks,
)
self.ap.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to VectorDB.')
return chunk_objects
except Exception as e: self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
session.rollback() # Rollback on any error
self.ap.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True) return chunk_entities
raise # Re-raise the exception to propagate it
finally:
session.close() # Ensure the session is always closed

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import PyPDF2 import PyPDF2
import io import io
from docx import Document from docx import Document

View File

@@ -1,99 +1,46 @@
from __future__ import annotations from __future__ import annotations
import logging
import numpy as np # Make sure numpy is imported from . import base_service
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.vector.vdb import VectorDatabase
from ....core import app from ....core import app
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
logger = logging.getLogger(__name__) from ....entity.rag import retriever as retriever_entities
class Retriever(BaseService): class Retriever(base_service.BaseService):
def __init__(self, ap: app.Application): def __init__(self, ap: app.Application):
super().__init__() super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.ap = ap self.ap = ap
self.vector_db: VectorDatabase = ap.vector_db_mgr.vector_db
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: async def retrieve(
if not self.embedding_model: self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.') ) -> list[retriever_entities.RetrieveResultEntry]:
self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}")
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=[query],
extra_args={}, # TODO: add extra args
)
query_embedding: List[float] = await self.embedding_model.embed_query(query) chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k)
query_embedding_np = np.array([query_embedding], dtype=np.float32)
# collection名用kb_id假设retriever有kb_id属性或通过ap传递
kb_id = getattr(self, 'kb_id', None)
if not kb_id:
self.logger.warning('无法获取kb_id向量检索失败')
return []
chroma_results = await self._run_sync(self.vector_db.search, kb_id, query_embedding_np, k)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include' # 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get('ids', [[]])[0] matched_chroma_ids = chroma_results.get('ids', [[]])[0]
distances = chroma_results.get('distances', [[]])[0] distances = chroma_results.get('distances', [[]])[0]
chroma_metadatas = chroma_results.get('metadatas', [[]])[0] chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
chroma_documents = chroma_results.get('documents', [[]])[0]
if not matched_chroma_ids: if not matched_chroma_ids:
self.logger.info('No relevant chunks found in Chroma.') self.ap.logger.info('No relevant chunks found in Chroma.')
return [] return []
db_chunk_ids = [] result: list[retriever_entities.RetrieveResultEntry] = []
for metadata in chroma_metadatas:
if 'chunk_id' in metadata:
db_chunk_ids.append(metadata['chunk_id'])
else:
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
if not db_chunk_ids: for i, id in enumerate(matched_chroma_ids):
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.') entry = retriever_entities.RetrieveResultEntry(
return [] id=id,
metadata=chroma_metadatas[i],
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...') distance=distances[i],
chunks_from_db = await self._run_sync(
lambda cids: self._db_get_chunks_sync(
SessionLocal(), cids
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids,
)
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
results_list: List[Dict[str, Any]] = []
for i, chroma_id in enumerate(matched_chroma_ids):
try:
# Ensure original_chunk_id is int for DB lookup
original_chunk_id = int(chroma_id.split('_')[-1])
except (ValueError, IndexError):
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
continue
chunk_text_from_chroma = chroma_documents[i]
distance = float(distances[i])
file_id_from_chroma = chroma_metadatas[i].get('file_id')
chunk_from_db = chunk_map.get(original_chunk_id)
results_list.append(
{
'chunk_id': original_chunk_id,
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
'distance': distance,
'file_id': file_id_from_chroma,
}
) )
result.append(entry)
self.logger.info(f'Retrieved {len(results_list)} chunks.') return result
return results_list
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
    """Load the ``Chunk`` rows whose ids are in ``chunk_ids``.

    This method closes ``session`` itself, so callers hand over a session they
    do not use afterwards.

    Args:
        session: Session used for the query; closed before returning or raising.
        chunk_ids: Ids of the chunks to load.

    Returns:
        The matching ``Chunk`` objects.
    """
    self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
    try:
        return session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
    finally:
        # Close even when the query raises, so the session is not leaked
        session.close()

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
from typing import Any, List, Dict from typing import Any, Dict
import numpy as np import numpy as np
@@ -9,10 +9,10 @@ class VectorDatabase(abc.ABC):
def add_embeddings( def add_embeddings(
self, self,
collection: str, collection: str,
ids: List[str], ids: list[str],
embeddings: np.ndarray, embeddings_list: list[list[float]],
metadatas: List[Dict[str, Any]], metadatas: list[dict[str, Any]],
documents: List[str], documents: list[str],
) -> None: ) -> None:
"""向指定 collection 添加向量数据。""" """向指定 collection 添加向量数据。"""
pass pass

View File

@@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
import numpy as np import chromadb
from typing import Any, List, Dict from typing import Any
from chromadb import PersistentClient from chromadb import PersistentClient
from pkg.vector.vdb import VectorDatabase from pkg.vector.vdb import VectorDatabase
from pkg.core import app from pkg.core import app
@@ -12,7 +12,7 @@ class ChromaVectorDatabase(VectorDatabase):
self.client = PersistentClient(path=base_path) self.client = PersistentClient(path=base_path)
self._collections = {} self._collections = {}
def get_or_create_collection(self, collection: str): def get_or_create_collection(self, collection: str) -> chromadb.Collection:
if collection not in self._collections: if collection not in self._collections:
self._collections[collection] = self.client.get_or_create_collection(name=collection) self._collections[collection] = self.client.get_or_create_collection(name=collection)
self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.") self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
@@ -21,26 +21,25 @@ class ChromaVectorDatabase(VectorDatabase):
def add_embeddings( def add_embeddings(
self, self,
collection: str, collection: str,
ids: List[str], ids: list[str],
embeddings: np.ndarray, embeddings_list: list[list[float]],
metadatas: List[Dict[str, Any]], metadatas: list[dict[str, Any]],
documents: List[str],
) -> None: ) -> None:
col = self.get_or_create_collection(collection) col = self.get_or_create_collection(collection)
col.add(embeddings=embeddings.tolist(), ids=ids, metadatas=metadatas, documents=documents) col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.") self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
col = self.get_or_create_collection(collection) col = self.get_or_create_collection(collection)
results = col.query( results = col.query(
query_embeddings=query_embedding.tolist(), query_embeddings=query_embedding,
n_results=k, n_results=k,
include=['metadatas', 'distances', 'documents'], include=['metadatas', 'distances', 'documents'],
) )
self.ap.logger.debug(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
return results return results
def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None: def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None:
col = self.get_or_create_collection(collection) col = self.get_or_create_collection(collection)
col.delete(where=where) col.delete(where=where)
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}") self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")