diff --git a/.gitignore b/.gitignore
index 2869b7cc..db62bdca 100644
--- a/.gitignore
+++ b/.gitignore
@@ -42,4 +42,5 @@ botpy.log*
 test.py
 /web_ui
 .venv/
-uv.lock
\ No newline at end of file
+uv.lock
+/test
\ No newline at end of file
diff --git a/pkg/entity/persistence/rag.py b/pkg/entity/persistence/rag.py
index 95a78712..9ca84741 100644
--- a/pkg/entity/persistence/rag.py
+++ b/pkg/entity/persistence/rag.py
@@ -1,7 +1,8 @@
-from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary
+from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary
 from sqlalchemy.orm import declarative_base, sessionmaker
 from datetime import datetime
 import os
+import uuid
 
 Base = declarative_base()
 DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
@@ -35,11 +36,11 @@ class File(Base):
     path = Column(String)
     created_at = Column(DateTime, default=datetime.utcnow)
     file_type = Column(String)
-    status = Column(String, default='0')
+    status = Column(Integer, default=0)  # 0: uploaded and processing, 1: completed, 2: failed
 
 class Chunk(Base):
     __tablename__ = 'chunks'
-    id = Column(String, primary_key=True, index=True)
+    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
     file_id = Column(String, nullable=True)
     text = Column(Text)
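The schema changes work together: `File.status` becomes an integer state machine (0 = uploaded and processing, 1 = completed, 2 = failed), and `Chunk.id` now mints a fresh UUID per row instead of requiring the caller to supply one. Note the `Integer` import added above, which the new `status` column needs. A minimal self-contained sketch of how these column defaults behave, using a toy table against in-memory SQLite rather than the repo's real models:

```python
import uuid

from sqlalchemy import create_engine, Column, Integer, String
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()

class Chunk(Base):
    __tablename__ = 'chunks'
    # The lambda runs once per INSERT, so every row gets its own UUID.
    id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
    # Integer lifecycle mirroring File.status: 0 processing, 1 done, 2 failed.
    status = Column(Integer, default=0)

engine = create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

session.add(Chunk())
session.commit()
row = session.query(Chunk).first()
print(row.id, row.status)  # e.g. '1f2d...' 0
```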
diff --git a/pkg/rag/knowledge/mgr.py b/pkg/rag/knowledge/mgr.py
index 89e5b393..09023d03 100644
--- a/pkg/rag/knowledge/mgr.py
+++ b/pkg/rag/knowledge/mgr.py
@@ -7,6 +7,9 @@ from pkg.rag.knowledge.services.parser import FileParser
 from pkg.rag.knowledge.services.chunker import Chunker
 from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
 from pkg.core import app
+from pkg.rag.knowledge.services.embedder import Embedder
+from pkg.rag.knowledge.services.retriever import Retriever
+from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
 
 
 class RAGManager:
@@ -14,11 +17,12 @@ class RAGManager:
     def __init__(self, ap: app.Application):
         self.ap = ap
-        self.chroma_manager = None
+        self.chroma_manager = ChromaIndexManager()
         self.parser = FileParser()
         self.chunker = Chunker()
-        self.embedder = None
-        self.retriever = None
+        # Initialize the embedder and retriever with the configured model type and name
+        self.embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
+        self.retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
 
     async def initialize_rag_system(self):
         """Initializes the RAG system by creating database tables."""
@@ -140,7 +144,7 @@ class RAGManager:
             return []
 
     async def store_data(
-        self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base'
+        self, file_path: str, kb_id: str, file_type: str, file_id: str = None
     ):
         """
         Parses, chunks, embeds, and stores data from a given file into the RAG system.
@@ -151,58 +155,36 @@
         file_obj = None
         try:
-            kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
+            kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
             if not kb:
-                kb = KnowledgeBase(name=kb_name, description=kb_description)
-                session.add(kb)
-                session.commit()
-                session.refresh(kb)
-                self.ap.logger.info(f"Knowledge Base '{kb_name}' created during store_data.")
+                self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist.')
+                self.ap.logger.info(f'Created Knowledge Base with ID: {kb_id}')
             else:
-                self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
+                self.ap.logger.info(f"Knowledge Base '{kb_id}' already exists.")
 
             file_name = os.path.basename(file_path)
-            existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
-            if existing_file:
-                self.ap.logger.warning(
-                    f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage."
-                )
-                return
-
-            file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type)
-            session.add(file_obj)
-            session.commit()
-            session.refresh(file_obj)
-            self.ap.logger.info(
-                f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}"
-            )
-
             text = await self.parser.parse(file_path)
             if not text:
                 self.ap.logger.warning(
-                    f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.'
+                    f'No text extracted from file {file_path}.'
                 )
-                session.delete(file_obj)
-                session.commit()
                 return
 
             chunks_texts = await self.chunker.chunk(text)
             self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
 
-            await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
+            await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts)
             self.ap.logger.info(f'Data storage process completed for file: {file_path}')
 
         except Exception as e:
             session.rollback()
             self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
-            if file_obj and file_obj.id:
-                try:
-                    await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id)
-                except Exception as chroma_e:
-                    self.ap.logger.warning(
-                        f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}'
-                    )
             raise
         finally:
+            # Mark the file as completed; without the commit the status change would be lost
+            if file_id:
+                file_obj = session.query(File).filter_by(id=file_id).first()
+                if file_obj:
+                    file_obj.status = 1
+                    session.commit()
             session.close()
@@ -245,7 +226,6 @@ class RAGManager:
                 self.ap.logger.warning(
                     f'File with ID {file_id} not found in database. Skipping deletion of file record.'
                 )
-            session.commit()
 
             self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}')
         except Exception as e:
@@ -338,13 +318,13 @@ class RAGManager:
                 self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
                 return
 
-            # Update the file's kb_id
-            file_to_update = session.query(File).filter_by(id=file_id).first()
-            if not file_to_update:
-                self.ap.logger.error(f'File with ID {file_id} not found.')
+            if not self.ap.storage_mgr.storage_provider.exists(file_id):
+                self.ap.logger.error(f'File with ID {file_id} does not exist.')
                 return
-
-            file_to_update.kb_id = kb.id
+            self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.')
+            # Add a new file record
+            file_to_update = File(id=file_id, kb_id=kb.id)
+            session.add(file_to_update)
             session.commit()
             self.ap.logger.info(
                 f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
@@ -356,4 +336,20 @@ class RAGManager:
                 exc_info=True,
             )
         finally:
+            # Parse and store the file's content
+            try:
+                await self.store_data(
+                    file_path=os.path.join('data', 'storage', file_id),
+                    kb_id=knowledge_base_uuid,
+                    file_type=os.path.splitext(file_id)[1].lstrip('.'),
+                    file_id=file_id,
+                )
+            except Exception as store_e:
+                # If storing the data fails, mark the file status as failed
+                file_obj = session.query(File).filter_by(id=file_id).first()
+                if file_obj:
+                    file_obj.status = 2
+                    session.commit()
+                self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True)
+
             session.close()
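The `store_data` rewrite narrows its job to parse, chunk, and embed: the knowledge base and `File` row must already exist, `file_id` is threaded through to the embedder, and the `finally`/`except` blocks above flip `status` to 1 on success or 2 on failure. A hypothetical driver showing the intended call; `rag_mgr`, `kb_uuid`, and `file_id` are assumptions standing in for the association step's outputs, and the path layout copies the `os.path.join('data', 'storage', file_id)` convention used above:

```python
import asyncio
import os

async def ingest(rag_mgr, kb_uuid: str, file_id: str) -> None:
    # file_id doubles as the stored file name, so the extension
    # (and thus file_type) is recoverable from it.
    await rag_mgr.store_data(
        file_path=os.path.join('data', 'storage', file_id),
        kb_id=kb_uuid,
        file_type=os.path.splitext(file_id)[1].lstrip('.'),
        file_id=file_id,
    )

# asyncio.run(ingest(rag_mgr, kb_uuid, 'a1b2c3.pdf'))
```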
diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py
index 17202a7a..2db7c104 100644
--- a/pkg/rag/knowledge/services/chunker.py
+++ b/pkg/rag/knowledge/services/chunker.py
@@ -24,33 +24,28 @@ class Chunker(BaseService):
         """
         if not text:
             return []
-
-        # Simple whitespace-based splitting for demonstration
-        # For more advanced chunking, consider libraries like LangChain's text splitters
-        words = text.split()
-        chunks = []
-        current_chunk = []
+        # words = text.split()
+        # chunks = []
+        # current_chunk = []
 
-        for word in words:
-            current_chunk.append(word)
-            if len(current_chunk) > self.chunk_size:
-                chunks.append(" ".join(current_chunk[:self.chunk_size]))
-                current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
+        # for word in words:
+        #     current_chunk.append(word)
+        #     if len(current_chunk) > self.chunk_size:
+        #         chunks.append(" ".join(current_chunk[:self.chunk_size]))
+        #         current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
 
-        if current_chunk:
-            chunks.append(" ".join(current_chunk))
+        # if current_chunk:
+        #     chunks.append(" ".join(current_chunk))
 
         # A more robust chunking strategy (e.g., using recursive character text splitter)
-        # from langchain.text_splitter import RecursiveCharacterTextSplitter
-        # text_splitter = RecursiveCharacterTextSplitter(
-        #     chunk_size=self.chunk_size,
-        #     chunk_overlap=self.chunk_overlap,
-        #     length_function=len,
-        #     is_separator_regex=False,
-        # )
-        # return text_splitter.split_text(text)
-
-        return [chunk for chunk in chunks if chunk.strip()]  # Filter out empty chunks
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return text_splitter.split_text(text)
 
     async def chunk(self, text: str) -> List[str]:
         """
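The chunker now defers entirely to LangChain's `RecursiveCharacterTextSplitter`, which prefers paragraph, line, and word boundaries before falling back to hard character cuts, and carries up to `chunk_overlap` characters of context between adjacent chunks. A quick standalone check of that behaviour (same import path as the hunk; requires the `langchain` dependency):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(
    chunk_size=80,      # max characters per chunk, measured by length_function
    chunk_overlap=20,   # tail of one chunk repeated at the head of the next
    length_function=len,
    is_separator_regex=False,
)

text = (
    'LangBot ingests a document, splits it into chunks, and embeds each one.\n\n'
    'Overlap keeps sentences that straddle a boundary retrievable from either side.'
)
for i, chunk in enumerate(splitter.split_text(text)):
    print(i, len(chunk), repr(chunk))
```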
diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py
index 7e20b19a..063ae79e 100644
--- a/pkg/rag/knowledge/services/embedder.py
+++ b/pkg/rag/knowledge/services/embedder.py
@@ -12,7 +12,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager  # Import ChromaIndexManager
 logger = logging.getLogger(__name__)
 
 class Embedder(BaseService):
-    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
+    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None):
         super().__init__()
         self.logger = logging.getLogger(self.__class__.__name__)
         self.model_type = model_type
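Defaulting `chroma_manager` to `None` keeps existing call sites working while letting `RAGManager` inject the shared `ChromaIndexManager` shown in the mgr.py hunk. The pattern in isolation, with illustrative names rather than the repo's classes:

```python
from typing import Optional

class IndexManager:
    def add(self, text: str) -> None:
        print(f'indexed: {text}')

class Embedder:
    # Optional injection: a shared manager can be passed in, or omitted
    # when no vector index has been wired up yet.
    def __init__(self, index_manager: Optional[IndexManager] = None):
        self.index_manager = index_manager

    def embed_and_store(self, text: str) -> None:
        if self.index_manager is None:
            print('no index manager configured; skipping store')
            return
        self.index_manager.add(text)

shared = IndexManager()
Embedder(shared).embed_and_store('hello')  # indexed: hello
Embedder().embed_and_store('hello')        # skipping store
```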