fix: embedding and chunking

Author: WangCham
Date: 2025-07-12 01:07:49 +08:00
parent fe122281fd
commit f395cac893
5 changed files with 64 additions and 71 deletions


@@ -7,6 +7,9 @@ from pkg.rag.knowledge.services.parser import FileParser
 from pkg.rag.knowledge.services.chunker import Chunker
 from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
 from pkg.core import app
+from pkg.rag.knowledge.services.embedder import Embedder
+from pkg.rag.knowledge.services.retriever import Retriever
+from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
 class RAGManager:
@@ -14,11 +17,12 @@ class RAGManager:
     def __init__(self, ap: app.Application):
         self.ap = ap
-        self.chroma_manager = None
+        self.chroma_manager = ChromaIndexManager()
         self.parser = FileParser()
         self.chunker = Chunker()
-        self.embedder = None
-        self.retriever = None
+        # Initialize Embedder with targeted model type and name
+        self.embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
+        self.retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
     async def initialize_rag_system(self):
         """Initializes the RAG system by creating database tables."""
@@ -140,7 +144,7 @@ class RAGManager:
         return []
     async def store_data(
-        self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base'
+        self, file_path: str, kb_id: str, file_type: str, file_id: str = None
     ):
         """
         Parses, chunks, embeds, and stores data from a given file into the RAG system.
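`store_data` now addresses the knowledge base by `kb_id` instead of `kb_name`, and accepts the caller's `file_id` so chunks are keyed to an existing file record. A hedged call sketch (assumes `rag` is an initialized RAGManager; all ID values are placeholders):

```python
# IDs below are placeholders, not values from the commit.
await rag.store_data(
    file_path='data/storage/example.pdf',
    kb_id='kb-uuid-placeholder',
    file_type='pdf',
    file_id='example.pdf',
)
```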
@@ -151,58 +155,35 @@ class RAGManager:
         file_obj = None
         try:
-            kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
+            kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
             if not kb:
-                kb = KnowledgeBase(name=kb_name, description=kb_description)
-                session.add(kb)
-                session.commit()
-                session.refresh(kb)
-                self.ap.logger.info(f"Knowledge Base '{kb_name}' created during store_data.")
+                self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ')
+                self.ap.logger.info(f'Created Knowledge Base with ID: {kb_id}')
             else:
-                self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
+                self.ap.logger.info(f"Knowledge Base '{kb_id}' already exists.")
             file_name = os.path.basename(file_path)
-            existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
-            if existing_file:
-                self.ap.logger.warning(
-                    f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage."
-                )
-                return
-            file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type)
-            session.add(file_obj)
-            session.commit()
-            session.refresh(file_obj)
-            self.ap.logger.info(
-                f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}"
-            )
             text = await self.parser.parse(file_path)
             if not text:
                 self.ap.logger.warning(
-                    f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.'
+                    f'No text extracted from file {file_path}. '
                 )
-                session.delete(file_obj)
-                session.commit()
                 return
             chunks_texts = await self.chunker.chunk(text)
             self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
-            await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
+            await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts)
             self.ap.logger.info(f'Data storage process completed for file: {file_path}')
         except Exception as e:
             session.rollback()
             self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
-            if file_obj and file_obj.id:
-                try:
-                    await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id)
-                except Exception as chroma_e:
-                    self.ap.logger.warning(
-                        f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}'
-                    )
-            raise
         finally:
+            if file_id:
+                file_obj = session.query(File).filter_by(id=file_id).first()
+                if file_obj:
+                    file_obj.status = 1
             session.close()
     async def retrieve_data(self, query: str):
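Failure handling also changes shape: instead of deleting the file record and raising, the `finally` block marks the `File` row with a numeric status. Reading this hunk together with the association hunk below, `1` appears to mark successful embedding and `2` a failed store; the constant names here are illustrative only, the code stores bare integers:

```python
# Illustrative names for the bare integers the diff writes to File.status.
FILE_STATUS_EMBEDDED = 1  # set in store_data's finally block
FILE_STATUS_FAILED = 2    # set by the caller when store_data raises
```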
@@ -245,7 +226,6 @@ class RAGManager:
                 self.ap.logger.warning(
                     f'File with ID {file_id} not found in database. Skipping deletion of file record.'
                 )
-            session.commit()
             self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}')
         except Exception as e:
@@ -338,13 +318,13 @@ class RAGManager:
                 self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
                 return
-            # update the file's kb_id
-            file_to_update = session.query(File).filter_by(id=file_id).first()
-            if not file_to_update:
-                self.ap.logger.error(f'File with ID {file_id} not found.')
+            if not self.ap.storage_mgr.storage_provider.exists(file_id):
+                self.ap.logger.error(f'File with ID {file_id} does not exist.')
                 return
-            file_to_update.kb_id = kb.id
+            self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.')
+            # add new file record
+            file_to_update = File(id=file_id, kb_id=kb.id)
+            session.add(file_to_update)
             session.commit()
             self.ap.logger.info(
                 f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
             )
@@ -356,4 +336,20 @@ class RAGManager:
                 exc_info=True,
             )
         finally:
+            # parse the file
+            try:
+                await self.store_data(
+                    file_path = os.path.join('data', 'storage', file_id),
+                    kb_id=knowledge_base_uuid,
+                    file_type=os.path.splitext(file_id)[1].lstrip('.'),
+                    file_id=file_id
+                )
+            except Exception as store_e:
+                # if storing the data fails, mark the file status as failed
+                file_obj = session.query(File).filter_by(id=file_id).first()
+                if file_obj:
+                    file_obj.status = 2
+                    session.commit()
+                self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True)
             session.close()
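The new `finally` block derives `file_type` from the extension embedded in the file ID. A standalone sketch of that derivation, using only the standard library:

```python
import os

def derive_file_type(file_id: str) -> str:
    """Mirror the diff's extension-based file type derivation."""
    return os.path.splitext(file_id)[1].lstrip('.')

assert derive_file_type('report.pdf') == 'pdf'
assert derive_file_type('notes.md') == 'md'
assert derive_file_type('no_extension') == ''  # empty when the ID has no suffix
```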


@@ -24,33 +24,28 @@ class Chunker(BaseService):
         """
         if not text:
            return []
-        # Simple whitespace-based splitting for demonstration
-        # For more advanced chunking, consider libraries like LangChain's text splitters
-        words = text.split()
-        chunks = []
-        current_chunk = []
+        # words = text.split()
+        # chunks = []
+        # current_chunk = []
-        for word in words:
-            current_chunk.append(word)
-            if len(current_chunk) > self.chunk_size:
-                chunks.append(" ".join(current_chunk[:self.chunk_size]))
-                current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
+        # for word in words:
+        #     current_chunk.append(word)
+        #     if len(current_chunk) > self.chunk_size:
+        #         chunks.append(" ".join(current_chunk[:self.chunk_size]))
+        #         current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
-        if current_chunk:
-            chunks.append(" ".join(current_chunk))
+        # if current_chunk:
+        #     chunks.append(" ".join(current_chunk))
-        # A more robust chunking strategy (e.g., using recursive character text splitter)
-        # from langchain.text_splitter import RecursiveCharacterTextSplitter
-        # text_splitter = RecursiveCharacterTextSplitter(
-        #     chunk_size=self.chunk_size,
-        #     chunk_overlap=self.chunk_overlap,
-        #     length_function=len,
-        #     is_separator_regex=False,
-        # )
-        # return text_splitter.split_text(text)
-        return [chunk for chunk in chunks if chunk.strip()]  # Filter out empty chunks
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return text_splitter.split_text(text)
     async def chunk(self, text: str) -> List[str]:
         """


@@ -12,7 +12,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager  # Impor
 logger = logging.getLogger(__name__)
 class Embedder(BaseService):
-    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
+    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None):
         super().__init__()
         self.logger = logging.getLogger(self.__class__.__name__)
         self.model_type = model_type
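Making `chroma_manager` default to `None` lets callers construct an `Embedder` just to compute vectors, deferring the vector-store decision. A hedged sketch of the pattern the new default implies (it assumes the constructor stores the manager on `self`, which the visible lines do not show):

```python
# Now legal: an Embedder with no vector store attached.
embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3')

# Defensive guard before any store-backed operation (illustrative).
if embedder.chroma_manager is None:
    raise RuntimeError('embed_and_store requires a ChromaIndexManager')
```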