mirror of https://github.com/langbot-app/LangBot.git (synced 2025-11-25 11:29:39 +08:00)
fix: embedding and chunking
@@ -7,6 +7,9 @@ from pkg.rag.knowledge.services.parser import FileParser
 from pkg.rag.knowledge.services.chunker import Chunker
 from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
 from pkg.core import app
 from pkg.rag.knowledge.services.embedder import Embedder
 from pkg.rag.knowledge.services.retriever import Retriever
 from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager


 class RAGManager:
@@ -14,11 +17,12 @@ class RAGManager:

     def __init__(self, ap: app.Application):
         self.ap = ap
-        self.chroma_manager = None
+        self.chroma_manager = ChromaIndexManager()
         self.parser = FileParser()
         self.chunker = Chunker()
-        self.embedder = None
-        self.retriever = None
+        # Initialize Embedder with targeted model type and name
+        self.embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
+        self.retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)

     async def initialize_rag_system(self):
         """Initializes the RAG system by creating database tables."""
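Both services now receive the single ChromaIndexManager built in the constructor, so vectors written by the Embedder land in the very index the Retriever queries. A minimal standalone sketch of that shared-index pattern, using the public chromadb client API; the SharedIndex/SketchEmbedder/SketchRetriever names and the collection name are illustrative, not LangBot's actual classes:

    import chromadb

    class SharedIndex:
        """One Chroma collection handed to both the embedder and the retriever."""
        def __init__(self, collection_name: str = 'kb_chunks'):
            self._client = chromadb.Client()  # in-memory client, enough for a sketch
            self.collection = self._client.get_or_create_collection(collection_name)

    class SketchEmbedder:
        def __init__(self, index: SharedIndex):
            self.index = index

        def embed_and_store(self, file_id: str, chunks: list[str]) -> None:
            # A real embedder would call its model here; the sketch stores raw
            # documents and lets Chroma's default embedding function vectorize them.
            self.index.collection.add(
                ids=[f'{file_id}-{i}' for i in range(len(chunks))],
                documents=chunks,
                metadatas=[{'file_id': file_id}] * len(chunks),
            )

    class SketchRetriever:
        def __init__(self, index: SharedIndex):
            self.index = index

        def retrieve(self, query: str, k: int = 2):
            return self.index.collection.query(query_texts=[query], n_results=k)

    index = SharedIndex()
    SketchEmbedder(index).embed_and_store('file-1', ['alpha chunk', 'beta chunk'])
    print(SketchRetriever(index).retrieve('alpha'))

Passing the manager in, rather than letting each service construct its own, keeps both sides pointed at one collection and makes the dependency explicit.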
@@ -140,7 +144,7 @@ class RAGManager:
             return []

     async def store_data(
-        self, file_path: str, kb_name: str, file_type: str, kb_description: str = 'Default knowledge base'
+        self, file_path: str, kb_id: str, file_type: str, file_id: str = None
     ):
         """
         Parses, chunks, embeds, and stores data from a given file into the RAG system.
@@ -151,58 +155,35 @@ class RAGManager:
         file_obj = None

         try:
-            kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
+            kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
             if not kb:
-                kb = KnowledgeBase(name=kb_name, description=kb_description)
-                session.add(kb)
-                session.commit()
-                session.refresh(kb)
-                self.ap.logger.info(f"Knowledge Base '{kb_name}' created during store_data.")
+                self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ')
+                self.ap.logger.info(f'Created Knowledge Base with ID: {kb_id}')
             else:
-                self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
+                self.ap.logger.info(f"Knowledge Base '{kb_id}' already exists.")

             file_name = os.path.basename(file_path)
-            existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
-            if existing_file:
-                self.ap.logger.warning(
-                    f"File '{file_name}' already exists in knowledge base '{kb_name}'. Skipping storage."
-                )
-                return
-
-            file_obj = File(kb_id=kb.id, file_name=file_name, path=file_path, file_type=file_type)
-            session.add(file_obj)
-            session.commit()
-            session.refresh(file_obj)
-            self.ap.logger.info(
-                f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}"
-            )

             text = await self.parser.parse(file_path)
             if not text:
                 self.ap.logger.warning(
-                    f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.'
+                    f'No text extracted from file {file_path}. '
                 )
-                session.delete(file_obj)
-                session.commit()
                 return

             chunks_texts = await self.chunker.chunk(text)
             self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
-            await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
+            await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts)
             self.ap.logger.info(f'Data storage process completed for file: {file_path}')

         except Exception as e:
             session.rollback()
             self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
             if file_obj and file_obj.id:
                 try:
                     await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_obj.id)
                 except Exception as chroma_e:
                     self.ap.logger.warning(
                         f'Could not clean up ChromaDB entries for file_id {file_obj.id} after store_data failure: {chroma_e}'
                     )
             raise
         finally:
+            if file_id:
+                file_obj = session.query(File).filter_by(id=file_id).first()
+                if file_obj:
+                    file_obj.status = 1
             session.close()

     async def retrieve_data(self, query: str):
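With the new signature the caller supplies kb_id and file_id up front, and store_data no longer creates KnowledgeBase or File rows; it only flips the existing File row's status to 1 in the finally block. A hypothetical invocation under those assumptions (the rag instance, the ids, and the path layout are illustrative):

    import asyncio

    async def ingest(rag, file_id: str, kb_uuid: str) -> None:
        # Assumes a File row with this id already exists and the bytes live
        # under data/storage/<file_id>, the path store_data is now called with.
        await rag.store_data(
            file_path=f'data/storage/{file_id}',
            kb_id=kb_uuid,
            file_type=file_id.rsplit('.', 1)[-1],  # extension without the dot
            file_id=file_id,
        )

    # asyncio.run(ingest(rag, 'abc123.pdf', 'kb-uuid-1'))  # illustrative values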
@@ -245,7 +226,6 @@ class RAGManager:
                 self.ap.logger.warning(
                     f'File with ID {file_id} not found in database. Skipping deletion of file record.'
                 )

             session.commit()
             self.ap.logger.info(f'Successfully completed data deletion for file_id: {file_id}')
         except Exception as e:
@@ -338,13 +318,13 @@ class RAGManager:
                 self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
                 return

-            # Update the file's kb_id
-            file_to_update = session.query(File).filter_by(id=file_id).first()
-            if not file_to_update:
-                self.ap.logger.error(f'File with ID {file_id} not found.')
+            if not self.ap.storage_mgr.storage_provider.exists(file_id):
+                self.ap.logger.error(f'File with ID {file_id} does not exist.')
                 return

-            file_to_update.kb_id = kb.id
+            self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.')
+            # add new file record
+            file_to_update = File(id=file_id, kb_id=kb.id)
+            session.add(file_to_update)
             session.commit()
             self.ap.logger.info(
                 f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
@@ -356,4 +336,20 @@ class RAGManager:
                 exc_info=True,
             )
         finally:
+            # Parse the file
+            try:
+                await self.store_data(
+                    file_path=os.path.join('data', 'storage', file_id),
+                    kb_id=knowledge_base_uuid,
+                    file_type=os.path.splitext(file_id)[1].lstrip('.'),
+                    file_id=file_id,
+                )
+            except Exception as store_e:
+                # If storing the data fails, mark the file's status as failed
+                file_obj = session.query(File).filter_by(id=file_id).first()
+                if file_obj:
+                    file_obj.status = 2
+                    session.commit()
+                self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True)

             session.close()
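The bare integers written to File.status (1 on success in store_data's finally block, 2 on failure here) would read better as named constants; a hypothetical refactor sketch, not part of this commit, with meanings inferred from its usage:

    from enum import IntEnum

    class FileStatus(IntEnum):
        PENDING = 0    # assumed initial value; not shown in this diff
        PROCESSED = 1  # set in store_data's finally block
        FAILED = 2     # set when store_data raises

    # e.g. file_obj.status = FileStatus.FAILED instead of file_obj.status = 2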
@@ -24,33 +24,28 @@ class Chunker(BaseService):
         """
         if not text:
             return []

-        # Simple whitespace-based splitting for demonstration
-        # For more advanced chunking, consider libraries like LangChain's text splitters
-        words = text.split()
-        chunks = []
-        current_chunk = []
+        # words = text.split()
+        # chunks = []
+        # current_chunk = []

-        for word in words:
-            current_chunk.append(word)
-            if len(current_chunk) > self.chunk_size:
-                chunks.append(" ".join(current_chunk[:self.chunk_size]))
-                current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
+        # for word in words:
+        #     current_chunk.append(word)
+        #     if len(current_chunk) > self.chunk_size:
+        #         chunks.append(" ".join(current_chunk[:self.chunk_size]))
+        #         current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]

-        if current_chunk:
-            chunks.append(" ".join(current_chunk))
+        # if current_chunk:
+        #     chunks.append(" ".join(current_chunk))

-        # A more robust chunking strategy (e.g., using recursive character text splitter)
-        # from langchain.text_splitter import RecursiveCharacterTextSplitter
-        # text_splitter = RecursiveCharacterTextSplitter(
-        #     chunk_size=self.chunk_size,
-        #     chunk_overlap=self.chunk_overlap,
-        #     length_function=len,
-        #     is_separator_regex=False,
-        # )
-        # return text_splitter.split_text(text)
-
-        return [chunk for chunk in chunks if chunk.strip()]  # Filter out empty chunks
+        from langchain.text_splitter import RecursiveCharacterTextSplitter
+        text_splitter = RecursiveCharacterTextSplitter(
+            chunk_size=self.chunk_size,
+            chunk_overlap=self.chunk_overlap,
+            length_function=len,
+            is_separator_regex=False,
+        )
+        return text_splitter.split_text(text)

     async def chunk(self, text: str) -> List[str]:
         """
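The active path now delegates to LangChain's RecursiveCharacterTextSplitter, which measures chunk_size in characters (length_function=len) rather than in words as the retired loop did, so existing chunk_size/chunk_overlap settings change meaning. A standalone sketch with illustrative parameter values, not the ones Chunker is configured with:

    from langchain.text_splitter import RecursiveCharacterTextSplitter

    splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,    # maximum characters per chunk
        chunk_overlap=20,  # characters shared between consecutive chunks
        length_function=len,
        is_separator_regex=False,
    )

    text = 'LangBot parses knowledge-base files into plain text before chunking. ' * 10
    chunks = splitter.split_text(text)
    print(len(chunks), [len(c) for c in chunks])  # every chunk is <= 100 characters

Importing the splitter inside the method body, as the commit does, defers the langchain dependency until first use; after the first call the import is a cheap sys.modules lookup.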
@@ -12,7 +12,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager  # Impor
 logger = logging.getLogger(__name__)

 class Embedder(BaseService):
-    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
+    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None):
         super().__init__()
         self.logger = logging.getLogger(self.__class__.__name__)
         self.model_type = model_type
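A None default on a parameter annotated with a bare class type is implicit-Optional, which PEP 484 deprecates; the same pattern appears in the new store_data signature (file_id: str = None). A sketch of the explicit form, using a simplified stand-in class rather than the module's real code:

    from typing import Optional

    class EmbedderSketch:
        def __init__(self, model_type: str, model_name_key: str,
                     chroma_manager: Optional['ChromaIndexManager'] = None):
            # Stores the manager if provided; callers without one get None.
            self.model_type = model_type
            self.model_name_key = model_name_key
            self.chroma_manager = chroma_manager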