From 2f2db4d445d3d397e29b23a59ca103a24c82e395 Mon Sep 17 00:00:00 2001 From: Junyan Qin Date: Wed, 16 Jul 2025 21:17:18 +0800 Subject: [PATCH] feat(rag): make embedding and retrieving available --- pkg/api/http/controller/group.py | 4 +- pkg/api/http/controller/groups/files.py | 3 +- .../http/controller/groups/knowledge/base.py | 10 + pkg/api/http/service/knowledge.py | 7 + pkg/api/http/service/model.py | 2 +- pkg/core/app.py | 2 +- pkg/core/stages/build_app.py | 4 +- pkg/entity/rag/__init__.py | 0 pkg/entity/rag/retriever.py | 13 ++ pkg/provider/modelmgr/requester.py | 8 +- pkg/provider/modelmgr/requesters/chatcmpl.py | 7 +- pkg/rag/knowledge/{mgr.py => kbmgr.py} | 195 +++++------------- pkg/rag/knowledge/services/base_service.py | 19 +- pkg/rag/knowledge/services/chunker.py | 13 +- pkg/rag/knowledge/services/database.py | 23 --- pkg/rag/knowledge/services/embedder.py | 102 +++------ pkg/rag/knowledge/services/parser.py | 2 + pkg/rag/knowledge/services/retriever.py | 101 +++------ pkg/vector/vdb.py | 10 +- pkg/vector/vdbs/chroma.py | 23 +-- 20 files changed, 180 insertions(+), 368 deletions(-) create mode 100644 pkg/entity/rag/__init__.py create mode 100644 pkg/entity/rag/retriever.py rename pkg/rag/knowledge/{mgr.py => kbmgr.py} (60%) diff --git a/pkg/api/http/controller/group.py b/pkg/api/http/controller/group.py index 3f34d79b..2088ecc1 100644 --- a/pkg/api/http/controller/group.py +++ b/pkg/api/http/controller/group.py @@ -86,10 +86,10 @@ class RouterGroup(abc.ABC): try: return await f(*args, **kwargs) - except Exception: # 自动 500 + except Exception as e: # 自动 500 traceback.print_exc() # return self.http_status(500, -2, str(e)) - return self.http_status(500, -2, 'internal server error') + return self.http_status(500, -2, str(e)) new_f = handler_error new_f.__name__ = (self.name + rule).replace('/', '__') diff --git a/pkg/api/http/controller/groups/files.py b/pkg/api/http/controller/groups/files.py index d08cbd71..b3c1a3f1 100644 --- a/pkg/api/http/controller/groups/files.py +++ b/pkg/api/http/controller/groups/files.py @@ -34,8 +34,9 @@ class FilesRouterGroup(group.RouterGroup): file_bytes = await asyncio.to_thread(file.stream.read) extension = file.filename.split('.')[-1] + file_name = file.filename.split('.')[0] - file_key = str(uuid.uuid4()) + '.' + extension + file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension # save file to storage await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes) return self.success( diff --git a/pkg/api/http/controller/groups/knowledge/base.py b/pkg/api/http/controller/groups/knowledge/base.py index 866b4af2..5fd80cbd 100644 --- a/pkg/api/http/controller/groups/knowledge/base.py +++ b/pkg/api/http/controller/groups/knowledge/base.py @@ -72,3 +72,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup): async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str: await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id) return self.success({}) + + @self.route( + '//retrieve', + methods=['POST'], + ) + async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str: + json_data = await quart.request.json + query = json_data.get('query') + results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query) + return self.success(data={'results': results}) diff --git a/pkg/api/http/service/knowledge.py b/pkg/api/http/service/knowledge.py index 5d702ba4..e42a14a7 100644 --- a/pkg/api/http/service/knowledge.py +++ b/pkg/api/http/service/knowledge.py @@ -67,6 +67,13 @@ class KnowledgeService: raise Exception('Knowledge base not found') return await runtime_kb.store_file(file_id) + async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]: + """检索知识库""" + runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid) + if runtime_kb is None: + raise Exception('Knowledge base not found') + return [result.model_dump() for result in await runtime_kb.retrieve(query)] + async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]: """获取知识库文件""" result = await self.ap.persistence_mgr.execute_async( diff --git a/pkg/api/http/service/model.py b/pkg/api/http/service/model.py index 3a4998e2..d8457da3 100644 --- a/pkg/api/http/service/model.py +++ b/pkg/api/http/service/model.py @@ -186,6 +186,6 @@ class EmbeddingModelsService: await runtime_embedding_model.requester.invoke_embedding( model=runtime_embedding_model, - input_text='Hello, world!', + input_text=['Hello, world!'], extra_args={}, ) diff --git a/pkg/core/app.py b/pkg/core/app.py index 092676c6..f0c30aee 100644 --- a/pkg/core/app.py +++ b/pkg/core/app.py @@ -28,7 +28,7 @@ from ..storage import mgr as storagemgr from ..utils import logcache from . import taskmgr from . import entities as core_entities -from ..rag.knowledge import mgr as rag_mgr +from ..rag.knowledge import kbmgr as rag_mgr from ..vector import mgr as vectordb_mgr diff --git a/pkg/core/stages/build_app.py b/pkg/core/stages/build_app.py index d9521274..e48fdd99 100644 --- a/pkg/core/stages/build_app.py +++ b/pkg/core/stages/build_app.py @@ -9,7 +9,7 @@ from ...command import cmdmgr from ...provider.session import sessionmgr as llm_session_mgr from ...provider.modelmgr import modelmgr as llm_model_mgr from ...provider.tools import toolmgr as llm_tool_mgr -from ...rag.knowledge import mgr as rag_mgr +from ...rag.knowledge import kbmgr as rag_mgr from ...platform import botmgr as im_mgr from ...persistence import mgr as persistencemgr from ...api.http.controller import main as http_controller @@ -92,7 +92,7 @@ class BuildAppStage(stage.BootingStage): ap.pipeline_mgr = pipeline_mgr rag_mgr_inst = rag_mgr.RAGManager(ap) - await rag_mgr_inst.initialize_rag_system() + await rag_mgr_inst.initialize() ap.rag_mgr = rag_mgr_inst # 初始化向量数据库管理器 diff --git a/pkg/entity/rag/__init__.py b/pkg/entity/rag/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pkg/entity/rag/retriever.py b/pkg/entity/rag/retriever.py new file mode 100644 index 00000000..becaf8db --- /dev/null +++ b/pkg/entity/rag/retriever.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import pydantic + +from typing import Any + + +class RetrieveResultEntry(pydantic.BaseModel): + id: str + + metadata: dict[str, Any] + + distance: float diff --git a/pkg/provider/modelmgr/requester.py b/pkg/provider/modelmgr/requester.py index 0abebfa5..17697cdb 100644 --- a/pkg/provider/modelmgr/requester.py +++ b/pkg/provider/modelmgr/requester.py @@ -101,18 +101,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta): async def invoke_embedding( self, model: RuntimeEmbeddingModel, - input_text: str, + input_text: list[str], extra_args: dict[str, typing.Any] = {}, - ) -> list[float]: + ) -> list[list[float]]: """调用 Embedding API Args: query (core_entities.Query): 请求上下文 model (RuntimeEmbeddingModel): 使用的模型信息 - input_text (str): 输入文本 + input_text (list[str]): 输入文本 extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}. Returns: - list[float]: 返回的 embedding 向量 + list[list[float]]: 返回的 embedding 向量 """ pass diff --git a/pkg/provider/modelmgr/requesters/chatcmpl.py b/pkg/provider/modelmgr/requesters/chatcmpl.py index 5dadab7d..aaaf3751 100644 --- a/pkg/provider/modelmgr/requesters/chatcmpl.py +++ b/pkg/provider/modelmgr/requesters/chatcmpl.py @@ -145,9 +145,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): async def invoke_embedding( self, model: requester.RuntimeEmbeddingModel, - input_text: str, + input_text: list[str], extra_args: dict[str, typing.Any] = {}, - ) -> list[float]: + ) -> list[list[float]]: """调用 Embedding API""" self.client.api_key = model.token_mgr.get_token() @@ -163,7 +163,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester): try: resp = await self.client.embeddings.create(**args) - return resp.data[0].embedding + + return [d.embedding for d in resp.data] except asyncio.TimeoutError: raise errors.RequesterError('请求超时') except openai.BadRequestError as e: diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/kbmgr.py similarity index 60% rename from pkg/rag/knowledge/mgr.py rename to pkg/rag/knowledge/kbmgr.py index 6e5fe366..46be7f75 100644 --- a/pkg/rag/knowledge/mgr.py +++ b/pkg/rag/knowledge/kbmgr.py @@ -1,21 +1,15 @@ from __future__ import annotations -import os import asyncio import traceback import uuid -from pkg.rag.knowledge.services.parser import FileParser -from pkg.rag.knowledge.services.chunker import Chunker -from pkg.rag.knowledge.services.database import ( - KnowledgeBase, - File, - Chunk, -) +from .services import parser, chunker from pkg.core import app from pkg.rag.knowledge.services.embedder import Embedder from pkg.rag.knowledge.services.retriever import Retriever import sqlalchemy from ...entity.persistence import rag as persistence_rag from pkg.core import taskmgr +from ...entity.rag import retriever as retriever_entities class RuntimeKnowledgeBase: @@ -23,9 +17,9 @@ class RuntimeKnowledgeBase: knowledge_base_entity: persistence_rag.KnowledgeBase - parser: FileParser + parser: parser.FileParser - chunker: Chunker + chunker: chunker.Chunker embedder: Embedder @@ -34,8 +28,8 @@ class RuntimeKnowledgeBase: def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase): self.ap = ap self.knowledge_base_entity = knowledge_base_entity - self.parser = FileParser(ap=self.ap) - self.chunker = Chunker(ap=self.ap) + self.parser = parser.FileParser(ap=self.ap) + self.chunker = chunker.Chunker(ap=self.ap) self.embedder = Embedder(ap=self.ap) self.retriever = Retriever(ap=self.ap) # 传递kb_id给retriever @@ -66,9 +60,16 @@ class RuntimeKnowledgeBase: raise Exception(f'No chunks extracted from file {file.file_name}') task_context.set_current_action('Embedding chunks') + + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( + self.knowledge_base_entity.embedding_model_uuid + ) # embed chunks await self.embedder.embed_and_store( - file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid + kb_id=self.knowledge_base_entity.uuid, + file_id=file.uuid, + chunks=chunks_texts, + embedding_model=embedding_model, ) # set file status to completed @@ -79,7 +80,8 @@ class RuntimeKnowledgeBase: ) except Exception as e: - self.ap.logger.error(f'Error storing file {file.file_id}: {e}') + self.ap.logger.error(f'Error storing file {file.uuid}: {e}') + traceback.print_exc() # set file status to failed await self.ap.persistence_mgr.execute_async( sqlalchemy.update(persistence_rag.File) @@ -89,7 +91,7 @@ class RuntimeKnowledgeBase: raise - async def store_file(self, file_id: str) -> int: + async def store_file(self, file_id: str) -> str: # pre checking if not await self.ap.storage_mgr.storage_provider.exists(file_id): raise Exception(f'File {file_id} not found') @@ -97,22 +99,24 @@ class RuntimeKnowledgeBase: file_uuid = str(uuid.uuid4()) kb_id = self.knowledge_base_entity.uuid file_name = file_id - extension = os.path.splitext(file_id)[1].lstrip('.') + extension = file_name.split('.')[-1] - file = persistence_rag.File( - uuid=file_uuid, - kb_id=kb_id, - file_name=file_name, - extension=extension, - status='pending', - ) + file_obj_data = { + 'uuid': file_uuid, + 'kb_id': kb_id, + 'file_name': file_name, + 'extension': extension, + 'status': 'pending', + } - await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict())) + file_obj = persistence_rag.File(**file_obj_data) + + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data)) # run background task asynchronously ctx = taskmgr.TaskContext.new() wrapper = self.ap.task_mgr.create_user_task( - self._store_file_task(file, task_context=ctx), + self._store_file_task(file_obj, task_context=ctx), kind='knowledge-operation', name=f'knowledge-store-file-{file_id}', label=f'Store file {file_id}', @@ -120,6 +124,12 @@ class RuntimeKnowledgeBase: ) return wrapper.id + async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]: + embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid( + self.knowledge_base_entity.embedding_model_uuid + ) + return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model) + async def dispose(self): pass @@ -134,7 +144,7 @@ class RAGManager: self.knowledge_bases = [] async def initialize(self): - pass + await self.load_knowledge_bases_from_db() async def load_knowledge_bases_from_db(self): self.ap.logger.info('Loading knowledge bases from db...') @@ -183,57 +193,6 @@ class RAGManager: self.knowledge_bases.remove(kb) return - async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None): - """ - Parses, chunks, embeds, and stores data from a given file into the RAG system. - Associates the file with a knowledge base using kb_id in the File table. - """ - self.ap.logger.info(f'Starting data storage process for file: {file_path}') - session = SessionLocal() - file_obj = None - - try: - kb = session.query(KnowledgeBase).filter_by(id=kb_id).first() - if not kb: - self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ') - return - # get embedding model - embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid) - file_name = os.path.basename(file_path) - text = await self.parser.parse(file_path) - if not text: - self.ap.logger.warning(f'No text extracted from file {file_path}. ') - return - - chunks_texts = await self.chunker.chunk(text) - self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.") - await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model) - self.ap.logger.info(f'Data storage process completed for file: {file_path}') - - except Exception as e: - session.rollback() - self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True) - raise - finally: - if file_id: - file_obj = session.query(File).filter_by(id=file_id).first() - if file_obj: - file_obj.status = 1 - session.close() - - async def retrieve_data(self, query: str): - """ - Retrieves relevant data chunks based on a given query using the configured retriever. - """ - self.ap.logger.info(f"Starting data retrieval process for query: '{query}'") - try: - retrieved_chunks = await self.retriever.retrieve(query) - self.ap.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.') - return retrieved_chunks - except Exception as e: - self.ap.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) - return [] - async def delete_data_by_file_id(self, file_id: str): """ Deletes all data associated with a specific file ID, including its chunks and vectors, @@ -243,7 +202,7 @@ class RAGManager: session = SessionLocal() try: # delete vectors - await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) + await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id) self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}') chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() @@ -333,74 +292,18 @@ class RAGManager: if session.is_active: session.close() - async def get_file_content_by_file_id(self, file_id: str) -> str: - file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) + # async def get_file_content_by_file_id(self, file_id: str) -> str: + # file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id) - _, ext = os.path.splitext(file_id.lower()) - ext = ext.lstrip('.') + # _, ext = os.path.splitext(file_id.lower()) + # ext = ext.lstrip('.') - try: - text = file_bytes.decode('utf-8') - except UnicodeDecodeError: - return '[非文本文件或编码无法识别]' + # try: + # text = file_bytes.decode('utf-8') + # except UnicodeDecodeError: + # return '[非文本文件或编码无法识别]' - if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: - return text - else: - return f'[未知类型: .{ext}]' - - async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None: - """ - Associates a file with a knowledge base by updating the kb_id in the File table. - """ - self.ap.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}') - session = SessionLocal() - try: - # 查询知识库是否存在 - kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first() - if not kb: - self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.') - return - - if not await self.ap.storage_mgr.storage_provider.exists(file_id): - self.ap.logger.error(f'File with ID {file_id} does not exist.') - return - self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.') - # add new file record - file_to_update = File( - id=file_id, - kb_id=kb.id, - file_name=file_id, - path=os.path.join('data', 'storage', file_id), - file_type=os.path.splitext(file_id)[1].lstrip('.'), - status=0, - ) - session.add(file_to_update) - session.commit() - self.ap.logger.info( - f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}' - ) - except Exception as e: - session.rollback() - self.ap.logger.error( - f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}', - exc_info=True, - ) - finally: - # 进行文件解析 - try: - await self.store_data( - file_path=os.path.join('data', 'storage', file_id), - kb_id=knowledge_base_uuid, - file_type=os.path.splitext(file_id)[1].lstrip('.'), - file_id=file_id, - ) - except Exception: - # 如果存储数据时出错,更新文件状态为失败 - file_obj = session.query(File).filter_by(id=file_id).first() - if file_obj: - file_obj.status = 2 - session.commit() - self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True) - - session.close() + # if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']: + # return text + # else: + # return f'[未知类型: .{ext}]' diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py index 4ff1ce39..0f71a508 100644 --- a/pkg/rag/knowledge/services/base_service.py +++ b/pkg/rag/knowledge/services/base_service.py @@ -1,26 +1,15 @@ # 封装异步操作 import asyncio -import logging -from pkg.rag.knowledge.services.database import SessionLocal + class BaseService: def __init__(self): - self.logger = logging.getLogger(self.__class__.__name__) - self.db_session_factory = SessionLocal + pass async def _run_sync(self, func, *args, **kwargs): """ 在单独的线程中运行同步函数。 如果第一个参数是 session,则在 to_thread 中获取新的 session。 """ - - if getattr(func, '__name__', '').startswith('_db_'): - session = await asyncio.to_thread(self.db_session_factory) - try: - result = await asyncio.to_thread(func, session, *args, **kwargs) - return result - finally: - session.close() - else: - # 否则,直接运行同步函数 - return await asyncio.to_thread(func, *args, **kwargs) \ No newline at end of file + + return await asyncio.to_thread(func, *args, **kwargs) diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py index 93b10a55..9aa1810b 100644 --- a/pkg/rag/knowledge/services/chunker.py +++ b/pkg/rag/knowledge/services/chunker.py @@ -1,24 +1,21 @@ -# services/chunker.py -import logging +from __future__ import annotations + from typing import List -from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync +from pkg.rag.knowledge.services import base_service from pkg.core import app -logger = logging.getLogger(__name__) - -class Chunker(BaseService): +class Chunker(base_service.BaseService): """ A class for splitting long texts into smaller, overlapping chunks. """ def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50): - super().__init__(ap) # Initialize BaseService self.ap = ap self.chunk_size = chunk_size self.chunk_overlap = chunk_overlap if self.chunk_overlap >= self.chunk_size: - self.logger.warning( + self.ap.logger.warning( 'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.' ) diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py index bc5caa10..e69de29b 100644 --- a/pkg/rag/knowledge/services/database.py +++ b/pkg/rag/knowledge/services/database.py @@ -1,23 +0,0 @@ -# 全部迁移过去 - -from pkg.entity.persistence.rag import ( - create_db_and_tables, - SessionLocal, - Base, - engine, - KnowledgeBase, - File, - Chunk, - Vector, -) - -__all__ = [ - "create_db_and_tables", - "SessionLocal", - "Base", - "engine", - "KnowledgeBase", - "File", - "Chunk", - "Vector", -] diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py index 213896a1..6b019433 100644 --- a/pkg/rag/knowledge/services/embedder.py +++ b/pkg/rag/knowledge/services/embedder.py @@ -1,12 +1,11 @@ from __future__ import annotations -import asyncio -import numpy as np +import uuid from typing import List -from sqlalchemy.orm import Session from pkg.rag.knowledge.services.base_service import BaseService -from pkg.rag.knowledge.services.database import Chunk, SessionLocal +from ....entity.persistence import rag as persistence_rag from ....core import app from ....provider.modelmgr.requester import RuntimeEmbeddingModel +import sqlalchemy class Embedder(BaseService): @@ -14,74 +13,41 @@ class Embedder(BaseService): super().__init__() self.ap = ap - def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]): - """ - Saves chunks to the relational database and returns the created Chunk objects. - This function assumes it's called within a context where the session - will be committed/rolled back and closed by the caller. - """ - self.ap.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).') - chunk_objects = [] - for text in chunks_texts: - chunk = Chunk(file_id=file_id, text=text) - session.add(chunk) - chunk_objects.append(chunk) - session.flush() # This populates the .id attribute for each new chunk object - self.ap.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.') - return chunk_objects - async def embed_and_store( - self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel - ) -> List[Chunk]: - session = SessionLocal() # Start a session that will live for the whole operation - chunk_objects = [] - try: - # 1. Save chunks to the relational database first to get their IDs - # We call _db_save_chunks_sync directly without _run_sync's session management - # because we manage the session here across multiple async calls. - chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks) - session.commit() # Commit chunks to make their IDs permanent and accessible + self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel + ) -> list[persistence_rag.Chunk]: + # save chunk to db + chunk_entities: list[persistence_rag.Chunk] = [] + chunk_ids: list[str] = [] - if not chunk_objects: - self.ap.logger.warning( - f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.' - ) - return [] + for chunk_text in chunks: + chunk_uuid = str(uuid.uuid4()) + chunk_ids.append(chunk_uuid) + chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text) + chunk_entities.append(chunk_entity) - # get the embeddings for the chunks - embeddings: list[list[float]] = [] + chunk_dicts = [ + self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities + ] - for chunk in chunks: - result = await embedding_model.requester.invoke_embedding( - model=embedding_model, - input_text=chunk, - ) - embeddings.append(result) + await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts)) - embeddings_np = np.array(embeddings, dtype=np.float32) + # get embeddings + embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding( + model=embedding_model, + input_text=chunks, + extra_args={}, # TODO: add extra args + ) - chunk_ids = [c.id for c in chunk_objects] - # collection名用kb_id(file对象有kb_id字段) - kb_id = session.query(Chunk).filter_by(id=chunk_ids[0]).first().file.kb_id if chunk_ids else None - if not kb_id: - self.ap.logger.warning('无法获取kb_id,向量存储失败') - return chunk_objects - chroma_ids = [f'{file_id}_{cid}' for cid in chunk_ids] - metadatas = [{'file_id': file_id, 'chunk_id': cid} for cid in chunk_ids] - await self._run_sync( - self.ap.vector_db_mgr.vector_db.add_embeddings, - kb_id, - chroma_ids, - embeddings_np, - metadatas, - chunks, - ) - self.ap.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to VectorDB.') - return chunk_objects + # save embeddings to vdb + await self._run_sync( + self.ap.vector_db_mgr.vector_db.add_embeddings, + kb_id, + chunk_ids, + embeddings_list, + chunk_dicts, + ) - except Exception as e: - session.rollback() # Rollback on any error - self.ap.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True) - raise # Re-raise the exception to propagate it - finally: - session.close() # Ensure the session is always closed + self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.') + + return chunk_entities diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py index 91b4f9ff..6a683aa5 100644 --- a/pkg/rag/knowledge/services/parser.py +++ b/pkg/rag/knowledge/services/parser.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import PyPDF2 import io from docx import Document diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py index 3385021a..fc403a57 100644 --- a/pkg/rag/knowledge/services/retriever.py +++ b/pkg/rag/knowledge/services/retriever.py @@ -1,99 +1,46 @@ from __future__ import annotations -import logging -import numpy as np # Make sure numpy is imported -from typing import List, Dict, Any -from sqlalchemy.orm import Session -from pkg.rag.knowledge.services.base_service import BaseService -from pkg.rag.knowledge.services.database import Chunk, SessionLocal -from pkg.vector.vdb import VectorDatabase + +from . import base_service from ....core import app - -logger = logging.getLogger(__name__) +from ....provider.modelmgr.requester import RuntimeEmbeddingModel +from ....entity.rag import retriever as retriever_entities -class Retriever(BaseService): +class Retriever(base_service.BaseService): def __init__(self, ap: app.Application): super().__init__() - self.logger = logging.getLogger(self.__class__.__name__) self.ap = ap - self.vector_db: VectorDatabase = ap.vector_db_mgr.vector_db - async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: - if not self.embedding_model: - raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.') + async def retrieve( + self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5 + ) -> list[retriever_entities.RetrieveResultEntry]: + self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}") - self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") + query_embedding: list[float] = await embedding_model.requester.invoke_embedding( + model=embedding_model, + input_text=[query], + extra_args={}, # TODO: add extra args + ) - query_embedding: List[float] = await self.embedding_model.embed_query(query) - query_embedding_np = np.array([query_embedding], dtype=np.float32) - - # collection名用kb_id(假设retriever有kb_id属性或通过ap传递) - kb_id = getattr(self, 'kb_id', None) - if not kb_id: - self.logger.warning('无法获取kb_id,向量检索失败') - return [] - chroma_results = await self._run_sync(self.vector_db.search, kb_id, query_embedding_np, k) + chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k) # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' matched_chroma_ids = chroma_results.get('ids', [[]])[0] distances = chroma_results.get('distances', [[]])[0] chroma_metadatas = chroma_results.get('metadatas', [[]])[0] - chroma_documents = chroma_results.get('documents', [[]])[0] if not matched_chroma_ids: - self.logger.info('No relevant chunks found in Chroma.') + self.ap.logger.info('No relevant chunks found in Chroma.') return [] - db_chunk_ids = [] - for metadata in chroma_metadatas: - if 'chunk_id' in metadata: - db_chunk_ids.append(metadata['chunk_id']) - else: - self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.") + result: list[retriever_entities.RetrieveResultEntry] = [] - if not db_chunk_ids: - self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.') - return [] - - self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...') - chunks_from_db = await self._run_sync( - lambda cids: self._db_get_chunks_sync( - SessionLocal(), cids - ), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync - db_chunk_ids, - ) - - chunk_map = {chunk.id: chunk for chunk in chunks_from_db} - results_list: List[Dict[str, Any]] = [] - - for i, chroma_id in enumerate(matched_chroma_ids): - try: - # Ensure original_chunk_id is int for DB lookup - original_chunk_id = int(chroma_id.split('_')[-1]) - except (ValueError, IndexError): - self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.') - continue - - chunk_text_from_chroma = chroma_documents[i] - distance = float(distances[i]) - file_id_from_chroma = chroma_metadatas[i].get('file_id') - - chunk_from_db = chunk_map.get(original_chunk_id) - - results_list.append( - { - 'chunk_id': original_chunk_id, - 'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, - 'distance': distance, - 'file_id': file_id_from_chroma, - } + for i, id in enumerate(matched_chroma_ids): + entry = retriever_entities.RetrieveResultEntry( + id=id, + metadata=chroma_metadatas[i], + distance=distances[i], ) + result.append(entry) - self.logger.info(f'Retrieved {len(results_list)} chunks.') - return results_list - - def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: - self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).') - chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() - session.close() - return chunks + return result diff --git a/pkg/vector/vdb.py b/pkg/vector/vdb.py index 20eff831..2b7ca400 100644 --- a/pkg/vector/vdb.py +++ b/pkg/vector/vdb.py @@ -1,6 +1,6 @@ from __future__ import annotations import abc -from typing import Any, List, Dict +from typing import Any, Dict import numpy as np @@ -9,10 +9,10 @@ class VectorDatabase(abc.ABC): def add_embeddings( self, collection: str, - ids: List[str], - embeddings: np.ndarray, - metadatas: List[Dict[str, Any]], - documents: List[str], + ids: list[str], + embeddings_list: list[list[float]], + metadatas: list[dict[str, Any]], + documents: list[str], ) -> None: """向指定 collection 添加向量数据。""" pass diff --git a/pkg/vector/vdbs/chroma.py b/pkg/vector/vdbs/chroma.py index c249c0ba..8f295931 100644 --- a/pkg/vector/vdbs/chroma.py +++ b/pkg/vector/vdbs/chroma.py @@ -1,6 +1,6 @@ from __future__ import annotations -import numpy as np -from typing import Any, List, Dict +import chromadb +from typing import Any from chromadb import PersistentClient from pkg.vector.vdb import VectorDatabase from pkg.core import app @@ -12,7 +12,7 @@ class ChromaVectorDatabase(VectorDatabase): self.client = PersistentClient(path=base_path) self._collections = {} - def get_or_create_collection(self, collection: str): + def get_or_create_collection(self, collection: str) -> chromadb.Collection: if collection not in self._collections: self._collections[collection] = self.client.get_or_create_collection(name=collection) self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.") @@ -21,26 +21,25 @@ class ChromaVectorDatabase(VectorDatabase): def add_embeddings( self, collection: str, - ids: List[str], - embeddings: np.ndarray, - metadatas: List[Dict[str, Any]], - documents: List[str], + ids: list[str], + embeddings_list: list[list[float]], + metadatas: list[dict[str, Any]], ) -> None: col = self.get_or_create_collection(collection) - col.add(embeddings=embeddings.tolist(), ids=ids, metadatas=metadatas, documents=documents) + col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas) self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.") - def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]: + def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]: col = self.get_or_create_collection(collection) results = col.query( - query_embeddings=query_embedding.tolist(), + query_embeddings=query_embedding, n_results=k, include=['metadatas', 'distances', 'documents'], ) - self.ap.logger.debug(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") + self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.") return results - def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None: + def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None: col = self.get_or_create_collection(collection) col.delete(where=where) self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")