mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 19:37:36 +08:00
Merge remote-tracking branch 'wangcham/feat/rag' into feat/rag
This commit is contained in:
82
pkg/api/http/controller/groups/knowledge_base.py
Normal file
82
pkg/api/http/controller/groups/knowledge_base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
import quart
|
||||
from .. import group
|
||||
|
||||
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
|
||||
class KnowledgeBaseRouterGroup(group.RouterGroup):
|
||||
|
||||
# 定义成功方法
|
||||
def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response:
|
||||
return quart.jsonify({
|
||||
"code": code,
|
||||
"data": data or {},
|
||||
"msg": msg
|
||||
})
|
||||
|
||||
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
||||
|
||||
@self.route('', methods=['POST', 'GET'])
|
||||
async def _() -> str:
|
||||
|
||||
if quart.request.method == 'GET':
|
||||
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
|
||||
bases_list = [
|
||||
{
|
||||
"uuid": kb.id,
|
||||
"name": kb.name,
|
||||
"description": kb.description,
|
||||
} for kb in knowledge_bases
|
||||
]
|
||||
return self.success(code=0,
|
||||
data={'bases': bases_list},
|
||||
msg='ok')
|
||||
|
||||
json_data = await quart.request.json
|
||||
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
|
||||
json_data.get('name'),
|
||||
json_data.get('description')
|
||||
)
|
||||
return self.success(code=0,
|
||||
data={},
|
||||
msg='ok')
|
||||
|
||||
|
||||
@self.route('/<knowledge_base_uuid>', methods=['GET','DELETE'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
|
||||
|
||||
if knowledge_base is None:
|
||||
return self.http_status(404, -1, 'knowledge base not found')
|
||||
|
||||
return self.success(
|
||||
code=0,
|
||||
data={
|
||||
"name": knowledge_base.name,
|
||||
"description": knowledge_base.description,
|
||||
"uuid": knowledge_base.id
|
||||
},
|
||||
msg='ok'
|
||||
)
|
||||
elif quart.request.method == 'DELETE':
|
||||
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
@self.route('/<knowledge_base_uuid>/files', methods=['GET'])
|
||||
async def _(knowledge_base_uuid: str) -> str:
|
||||
if quart.request.method == 'GET':
|
||||
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
|
||||
return self.success(code=0,data=[{
|
||||
"id": file.id,
|
||||
"file_name": file.file_name,
|
||||
"status": file.status
|
||||
} for file in files],msg='ok')
|
||||
|
||||
# delete specific file in knowledge base
|
||||
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
|
||||
async def _(knowledge_base_uuid: str, file_id: str) -> str:
|
||||
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
|
||||
return self.success(code=0, msg='ok')
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import quart
|
||||
|
||||
from ... import group
|
||||
from .. import group
|
||||
|
||||
|
||||
@group.group_class('pipelines', '/api/v1/pipelines')
|
||||
|
||||
@@ -27,6 +27,7 @@ from ..storage import mgr as storagemgr
|
||||
from ..utils import logcache
|
||||
from . import taskmgr
|
||||
from . import entities as core_entities
|
||||
from pkg.rag.knowledge.RAG_Manager import RAG_Manager
|
||||
|
||||
|
||||
class Application:
|
||||
@@ -47,6 +48,7 @@ class Application:
|
||||
|
||||
model_mgr: llm_model_mgr.ModelManager = None
|
||||
|
||||
|
||||
# TODO 移动到 pipeline 里
|
||||
tool_mgr: llm_tool_mgr.ToolManager = None
|
||||
|
||||
@@ -99,6 +101,7 @@ class Application:
|
||||
|
||||
storage_mgr: storagemgr.StorageMgr = None
|
||||
|
||||
|
||||
# ========= HTTP Services =========
|
||||
|
||||
user_service: user_service.UserService = None
|
||||
@@ -111,6 +114,9 @@ class Application:
|
||||
|
||||
bot_service: bot_service.BotService = None
|
||||
|
||||
knowledge_base_service: RAG_Manager = None
|
||||
|
||||
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@@ -145,6 +151,7 @@ class Application:
|
||||
name='http-api-controller',
|
||||
scopes=[core_entities.LifecycleControlScope.APPLICATION],
|
||||
)
|
||||
|
||||
self.task_mgr.create_task(
|
||||
never_ending(),
|
||||
name='never-ending-task',
|
||||
|
||||
@@ -19,7 +19,7 @@ class LifecycleControlScope(enum.Enum):
|
||||
APPLICATION = 'application'
|
||||
PLATFORM = 'platform'
|
||||
PLUGIN = 'plugin'
|
||||
PROVIDER = 'provider'
|
||||
PROVIDER = 'provider'
|
||||
|
||||
|
||||
class LauncherTypes(enum.Enum):
|
||||
|
||||
@@ -9,6 +9,7 @@ from ...command import cmdmgr
|
||||
from ...provider.session import sessionmgr as llm_session_mgr
|
||||
from ...provider.modelmgr import modelmgr as llm_model_mgr
|
||||
from ...provider.tools import toolmgr as llm_tool_mgr
|
||||
from ...rag.knowledge.RAG_Manager import RAG_Manager as knowledge_base_mgr
|
||||
from ...platform import botmgr as im_mgr
|
||||
from ...persistence import mgr as persistencemgr
|
||||
from ...api.http.controller import main as http_controller
|
||||
@@ -101,6 +102,10 @@ class BuildAppStage(stage.BootingStage):
|
||||
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
|
||||
ap.embedding_models_service = embedding_models_service_inst
|
||||
|
||||
knowledge_base_service_inst = knowledge_base_mgr(ap)
|
||||
await knowledge_base_service_inst.initialize_rag_system()
|
||||
ap.knowledge_base_service = knowledge_base_service_inst
|
||||
|
||||
pipeline_service_inst = pipeline_service.PipelineService(ap)
|
||||
ap.pipeline_service = pipeline_service_inst
|
||||
|
||||
|
||||
14
pkg/entity/persistence/vector.py
Normal file
14
pkg/entity/persistence/vector.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
|
||||
from datetime import datetime
|
||||
import numpy as np # 用于处理从LargeBinary转换回来的embedding
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
|
||||
chunk = relationship("Chunk", back_populates="vector")
|
||||
288
pkg/rag/knowledge/RAG_Manager.py
Normal file
288
pkg/rag/knowledge/RAG_Manager.py
Normal file
@@ -0,0 +1,288 @@
|
||||
# RAG_Manager class (main class, adjust imports as needed)
|
||||
from __future__ import annotations # For type hinting in Python 3.7+
|
||||
import logging
|
||||
import os
|
||||
import asyncio
|
||||
from pkg.rag.knowledge.services.parser import FileParser
|
||||
from pkg.rag.knowledge.services.chunker import Chunker
|
||||
from pkg.rag.knowledge.services.embedder import Embedder
|
||||
from pkg.rag.knowledge.services.retriever import Retriever
|
||||
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
|
||||
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
from pkg.core import app # Adjust the import path as needed
|
||||
|
||||
|
||||
class RAG_Manager:
|
||||
|
||||
ap: app.Application
|
||||
|
||||
def __init__(self, ap: app.Application,logger: logging.Logger = None):
|
||||
self.ap = ap
|
||||
self.logger = logger or logging.getLogger(__name__)
|
||||
self.embedding_model_type = None
|
||||
self.embedding_model_name = None
|
||||
self.chroma_manager = None
|
||||
self.parser = None
|
||||
self.chunker = None
|
||||
self.embedder = None
|
||||
self.retriever = None
|
||||
|
||||
async def initialize_rag_system(self):
|
||||
await asyncio.to_thread(create_db_and_tables)
|
||||
|
||||
async def create_specific_model(self, embedding_model_type: str,
|
||||
embedding_model_name: str):
|
||||
self.embedding_model_type = embedding_model_type
|
||||
self.embedding_model_name = embedding_model_name
|
||||
|
||||
try:
|
||||
model = EmbeddingModelFactory.create_model(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name
|
||||
)
|
||||
self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}")
|
||||
except Exception as e:
|
||||
self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}")
|
||||
raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.")
|
||||
|
||||
self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}")
|
||||
|
||||
self.parser = FileParser()
|
||||
self.chunker = Chunker()
|
||||
# Pass chroma_manager to Embedder and Retriever
|
||||
self.embedder = Embedder(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name,
|
||||
chroma_manager=self.chroma_manager # Inject dependency
|
||||
)
|
||||
self.retriever = Retriever(
|
||||
model_type=self.embedding_model_type,
|
||||
model_name_key=self.embedding_model_name,
|
||||
chroma_manager=self.chroma_manager # Inject dependency
|
||||
)
|
||||
|
||||
|
||||
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = "", top_k: int = 5):
|
||||
"""
|
||||
Creates a new knowledge base with the given name and description.
|
||||
If a knowledge base with the same name already exists, it returns that one.
|
||||
"""
|
||||
try:
|
||||
def _get_kb_sync(name):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(name=name).first()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
|
||||
|
||||
if not kb:
|
||||
def _add_kb_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
new_kb = KnowledgeBase(name=kb_name, description=kb_description, embedding_model=embedding_model, top_k=top_k)
|
||||
session.add(new_kb)
|
||||
session.commit()
|
||||
session.refresh(new_kb)
|
||||
return new_kb
|
||||
finally:
|
||||
session.close()
|
||||
kb = await asyncio.to_thread(_add_kb_sync)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
|
||||
raise
|
||||
|
||||
async def get_all_knowledge_bases(self):
|
||||
"""
|
||||
Retrieves all knowledge bases from the database.
|
||||
"""
|
||||
try:
|
||||
def _get_all_kbs_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).all()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kbs = await asyncio.to_thread(_get_all_kbs_sync)
|
||||
return kbs
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_knowledge_base_by_id(self, kb_id: int):
|
||||
"""
|
||||
Retrieves a knowledge base by its ID.
|
||||
"""
|
||||
try:
|
||||
def _get_kb_sync(kb_id):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kb = await asyncio.to_thread(_get_kb_sync, kb_id)
|
||||
return kb
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
async def get_files_by_knowledge_base(self, kb_id: int):
|
||||
try:
|
||||
def _get_files_sync(kb_id):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(File).filter_by(kb_id=kb_id).all()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
files = await asyncio.to_thread(_get_files_sync, kb_id)
|
||||
return files
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
|
||||
async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"):
|
||||
self.logger.info(f"Starting data storage process for file: {file_path}")
|
||||
try:
|
||||
def _get_kb_sync(name):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
return session.query(KnowledgeBase).filter_by(name=name).first()
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
|
||||
|
||||
if not kb:
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.")
|
||||
def _add_kb_sync():
|
||||
session = SessionLocal()
|
||||
try:
|
||||
new_kb = KnowledgeBase(name=kb_name, description=kb_description)
|
||||
session.add(new_kb)
|
||||
session.commit()
|
||||
session.refresh(new_kb)
|
||||
return new_kb
|
||||
finally:
|
||||
session.close()
|
||||
kb = await asyncio.to_thread(_add_kb_sync)
|
||||
self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})")
|
||||
|
||||
def _add_file_sync(kb_id, file_name, path, file_type):
|
||||
session = SessionLocal()
|
||||
try:
|
||||
file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type)
|
||||
session.add(file)
|
||||
session.commit()
|
||||
session.refresh(file)
|
||||
return file
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type)
|
||||
self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})")
|
||||
|
||||
text = await self.parser.parse(file_path)
|
||||
if not text:
|
||||
self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.")
|
||||
# You might want to delete the file_obj from the DB here if it's empty.
|
||||
session = SessionLocal()
|
||||
try:
|
||||
session.delete(file_obj)
|
||||
session.commit()
|
||||
except Exception as del_e:
|
||||
self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}")
|
||||
finally:
|
||||
session.close()
|
||||
return
|
||||
|
||||
chunks_texts = await self.chunker.chunk(text)
|
||||
self.logger.info(f"Chunked into {len(chunks_texts)} pieces.")
|
||||
|
||||
# embed_and_store now handles both DB chunk saving and Chroma embedding
|
||||
await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
|
||||
|
||||
self.logger.info(f"Data storage process completed for file: {file_path}")
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True)
|
||||
# Consider cleaning up partially stored data if an error occurs.
|
||||
return
|
||||
|
||||
async def retrieve_data(self, query: str):
|
||||
self.logger.info(f"Starting data retrieval process for query: '{query}'")
|
||||
try:
|
||||
retrieved_chunks = await self.retriever.retrieve(query)
|
||||
self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.")
|
||||
return retrieved_chunks
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def delete_data_by_file_id(self, file_id: int):
|
||||
"""
|
||||
Deletes data associated with a specific file_id from both the relational DB and Chroma.
|
||||
"""
|
||||
self.logger.info(f"Starting data deletion process for file_id: {file_id}")
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 1. Delete from Chroma
|
||||
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
|
||||
|
||||
# 2. Delete chunks from relational DB
|
||||
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
|
||||
for chunk in chunks_to_delete:
|
||||
session.delete(chunk)
|
||||
self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.")
|
||||
|
||||
# 3. Delete file entry from relational DB
|
||||
file_to_delete = session.query(File).filter_by(id=file_id).first()
|
||||
if file_to_delete:
|
||||
session.delete(file_to_delete)
|
||||
self.logger.info(f"Deleted file entry {file_id} from relational DB.")
|
||||
else:
|
||||
self.logger.warning(f"File entry {file_id} not found in relational DB.")
|
||||
|
||||
session.commit()
|
||||
self.logger.info(f"Data deletion completed for file_id: {file_id}.")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def delete_kb_by_id(self, kb_id: int):
|
||||
"""
|
||||
Deletes a knowledge base and all associated files and chunks.
|
||||
"""
|
||||
self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}")
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 1. Get the knowledge base
|
||||
kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if not kb:
|
||||
self.logger.warning(f"Knowledge Base with ID {kb_id} not found.")
|
||||
return
|
||||
|
||||
# 2. Delete all files associated with this knowledge base
|
||||
files_to_delete = session.query(File).filter_by(kb_id=kb.id).all()
|
||||
for file in files_to_delete:
|
||||
await self.delete_data_by_file_id(file.id)
|
||||
|
||||
# 3. Delete the knowledge base itself
|
||||
session.delete(kb)
|
||||
session.commit()
|
||||
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
|
||||
finally:
|
||||
session.close()
|
||||
0
pkg/rag/knowledge/services/__init__.py
Normal file
0
pkg/rag/knowledge/services/__init__.py
Normal file
26
pkg/rag/knowledge/services/base_service.py
Normal file
26
pkg/rag/knowledge/services/base_service.py
Normal file
@@ -0,0 +1,26 @@
|
||||
# 封装异步操作
|
||||
import asyncio
|
||||
import logging
|
||||
from pkg.rag.knowledge.services.database import SessionLocal
|
||||
|
||||
class BaseService:
|
||||
def __init__(self):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.db_session_factory = SessionLocal
|
||||
|
||||
async def _run_sync(self, func, *args, **kwargs):
|
||||
"""
|
||||
在单独的线程中运行同步函数。
|
||||
如果第一个参数是 session,则在 to_thread 中获取新的 session。
|
||||
"""
|
||||
|
||||
if getattr(func, '__name__', '').startswith('_db_'):
|
||||
session = await asyncio.to_thread(self.db_session_factory)
|
||||
try:
|
||||
result = await asyncio.to_thread(func, session, *args, **kwargs)
|
||||
return result
|
||||
finally:
|
||||
session.close()
|
||||
else:
|
||||
# 否则,直接运行同步函数
|
||||
return await asyncio.to_thread(func, *args, **kwargs)
|
||||
65
pkg/rag/knowledge/services/chroma_manager.py
Normal file
65
pkg/rag/knowledge/services/chroma_manager.py
Normal file
@@ -0,0 +1,65 @@
|
||||
|
||||
import numpy as np
|
||||
import logging
|
||||
from chromadb import PersistentClient
|
||||
import os
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class ChromaIndexManager:
|
||||
def __init__(self, collection_name: str = "default_collection"):
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
|
||||
os.makedirs(chroma_data_path, exist_ok=True)
|
||||
self.client = PersistentClient(path=chroma_data_path)
|
||||
self._collection_name = collection_name
|
||||
self._collection = None
|
||||
|
||||
self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}")
|
||||
|
||||
@property
|
||||
def collection(self):
|
||||
if self._collection is None:
|
||||
self._collection = self.client.get_or_create_collection(name=self._collection_name)
|
||||
self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
|
||||
return self._collection
|
||||
|
||||
def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]):
|
||||
if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents):
|
||||
raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.")
|
||||
|
||||
chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)]
|
||||
metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)]
|
||||
|
||||
self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
|
||||
self.collection.add(
|
||||
embeddings=embeddings.tolist(),
|
||||
ids=chroma_ids,
|
||||
metadatas=metadatas,
|
||||
documents=documents
|
||||
)
|
||||
self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
|
||||
|
||||
def search_sync(self, query_embedding: np.ndarray, k: int = 5):
|
||||
"""
|
||||
Searches the Chroma collection for the top-k nearest neighbors.
|
||||
Args:
|
||||
query_embedding: A numpy array of the query embedding.
|
||||
k: The number of results to return.
|
||||
Returns:
|
||||
A dictionary containing query results from Chroma.
|
||||
"""
|
||||
self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.")
|
||||
results = self.collection.query(
|
||||
query_embeddings=query_embedding.tolist(),
|
||||
n_results=k,
|
||||
# REMOVE 'ids' from the include list. It's returned by default.
|
||||
include=["metadatas", "distances", "documents"]
|
||||
)
|
||||
self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.")
|
||||
return results
|
||||
|
||||
def delete_by_file_id_sync(self, file_id: int):
|
||||
self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.")
|
||||
self.collection.delete(where={"file_id": file_id})
|
||||
self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.")
|
||||
63
pkg/rag/knowledge/services/chunker.py
Normal file
63
pkg/rag/knowledge/services/chunker.py
Normal file
@@ -0,0 +1,63 @@
|
||||
# services/chunker.py
|
||||
import logging
|
||||
from typing import List
|
||||
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Chunker(BaseService):
|
||||
"""
|
||||
A class for splitting long texts into smaller, overlapping chunks.
|
||||
"""
|
||||
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
|
||||
super().__init__() # Initialize BaseService
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.chunk_size = chunk_size
|
||||
self.chunk_overlap = chunk_overlap
|
||||
if self.chunk_overlap >= self.chunk_size:
|
||||
self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.")
|
||||
|
||||
def _split_text_sync(self, text: str) -> List[str]:
|
||||
"""
|
||||
Synchronously splits a long text into chunks with specified overlap.
|
||||
This is a CPU-bound operation, intended to be run in a separate thread.
|
||||
"""
|
||||
if not text:
|
||||
return []
|
||||
|
||||
# Simple whitespace-based splitting for demonstration
|
||||
# For more advanced chunking, consider libraries like LangChain's text splitters
|
||||
words = text.split()
|
||||
chunks = []
|
||||
current_chunk = []
|
||||
|
||||
for word in words:
|
||||
current_chunk.append(word)
|
||||
if len(current_chunk) > self.chunk_size:
|
||||
chunks.append(" ".join(current_chunk[:self.chunk_size]))
|
||||
current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
|
||||
|
||||
if current_chunk:
|
||||
chunks.append(" ".join(current_chunk))
|
||||
|
||||
# A more robust chunking strategy (e.g., using recursive character text splitter)
|
||||
# from langchain.text_splitter import RecursiveCharacterTextSplitter
|
||||
# text_splitter = RecursiveCharacterTextSplitter(
|
||||
# chunk_size=self.chunk_size,
|
||||
# chunk_overlap=self.chunk_overlap,
|
||||
# length_function=len,
|
||||
# is_separator_regex=False,
|
||||
# )
|
||||
# return text_splitter.split_text(text)
|
||||
|
||||
return [chunk for chunk in chunks if chunk.strip()] # Filter out empty chunks
|
||||
|
||||
async def chunk(self, text: str) -> List[str]:
|
||||
"""
|
||||
Asynchronously chunks a given text into smaller pieces.
|
||||
"""
|
||||
self.logger.info(f"Chunking text (length: {len(text)})...")
|
||||
# Run the synchronous splitting logic in a separate thread
|
||||
chunks = await self._run_sync(self._split_text_sync, text)
|
||||
self.logger.info(f"Text chunked into {len(chunks)} pieces.")
|
||||
return chunks
|
||||
58
pkg/rag/knowledge/services/database.py
Normal file
58
pkg/rag/knowledge/services/database.py
Normal file
@@ -0,0 +1,58 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
|
||||
from datetime import datetime
|
||||
import numpy as np # 用于处理从LargeBinary转换回来的embedding
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model = Column(String, default="") # 默认嵌入模型
|
||||
top_k = Column(Integer, default=5) # 默认返回的top_k数量
|
||||
files = relationship("File", back_populates="knowledge_base")
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
kb_id = Column(Integer, ForeignKey('kb.id'))
|
||||
file_name = Column(String)
|
||||
path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
file_type = Column(String)
|
||||
status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误
|
||||
knowledge_base = relationship("KnowledgeBase", back_populates="files")
|
||||
chunks = relationship("Chunk", back_populates="file")
|
||||
|
||||
class Chunk(Base):
|
||||
__tablename__ = 'chunks'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
file_id = Column(Integer, ForeignKey('file.id'))
|
||||
text = Column(Text)
|
||||
|
||||
file = relationship("File", back_populates="chunks")
|
||||
vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
|
||||
chunk = relationship("Chunk", back_populates="vector")
|
||||
|
||||
# 数据库连接
|
||||
DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
# 创建所有表 (可以在应用启动时执行一次)
|
||||
def create_db_and_tables():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print("Database tables created/checked.")
|
||||
|
||||
# 定义嵌入维度(请根据你实际使用的模型调整)
|
||||
EMBEDDING_DIM = 1024
|
||||
93
pkg/rag/knowledge/services/embedder.py
Normal file
93
pkg/rag/knowledge/services/embedder.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# services/embedder.py
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np
|
||||
from typing import List
|
||||
from sqlalchemy.orm import Session
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager # Import the manager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Embedder(BaseService):
|
||||
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.model_type = model_type
|
||||
self.model_name_key = model_name_key
|
||||
self.chroma_manager = chroma_manager # Dependency Injection
|
||||
|
||||
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
|
||||
|
||||
def _load_embedding_model(self) -> BaseEmbeddingModel:
|
||||
self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...")
|
||||
try:
|
||||
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
|
||||
self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
|
||||
return model
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}")
|
||||
raise
|
||||
|
||||
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
|
||||
"""
|
||||
Saves chunks to the relational database and returns the created Chunk objects.
|
||||
This function assumes it's called within a context where the session
|
||||
will be committed/rolled back and closed by the caller.
|
||||
"""
|
||||
self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).")
|
||||
chunk_objects = []
|
||||
for text in chunks_texts:
|
||||
chunk = Chunk(file_id=file_id, text=text)
|
||||
session.add(chunk)
|
||||
chunk_objects.append(chunk)
|
||||
session.flush() # This populates the .id attribute for each new chunk object
|
||||
self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.")
|
||||
return chunk_objects
|
||||
|
||||
async def embed_and_store(self, file_id: int, chunks: List[str]):
|
||||
if not self.embedding_model:
|
||||
raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.")
|
||||
|
||||
self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...")
|
||||
|
||||
session = SessionLocal() # Start a session that will live for the whole operation
|
||||
chunk_objects = []
|
||||
try:
|
||||
# 1. Save chunks to the relational database first to get their IDs
|
||||
# We call _db_save_chunks_sync directly without _run_sync's session management
|
||||
# because we manage the session here across multiple async calls.
|
||||
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
|
||||
session.commit() # Commit chunks to make their IDs permanent and accessible
|
||||
|
||||
if not chunk_objects:
|
||||
self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.")
|
||||
return []
|
||||
|
||||
# 2. Generate embeddings
|
||||
embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks)
|
||||
embeddings_np = np.array(embeddings, dtype=np.float32)
|
||||
|
||||
if embeddings_np.shape[1] != self.embedding_model.embedding_dimension:
|
||||
self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.")
|
||||
raise ValueError("Embedding dimension mismatch during embedding process.")
|
||||
|
||||
self.logger.info("Saving embeddings to Chroma...")
|
||||
chunk_ids = [c.id for c in chunk_objects] # Now safe to access .id because session is still open and committed
|
||||
file_ids_for_chroma = [file_id] * len(chunk_ids)
|
||||
|
||||
await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call
|
||||
self.chroma_manager.add_embeddings_sync,
|
||||
file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents
|
||||
)
|
||||
self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.")
|
||||
return chunk_objects
|
||||
|
||||
except Exception as e:
|
||||
session.rollback() # Rollback on any error
|
||||
self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True)
|
||||
raise # Re-raise the exception to propagate it
|
||||
finally:
|
||||
session.close() # Ensure the session is always closed
|
||||
223
pkg/rag/knowledge/services/embedding_models.py
Normal file
223
pkg/rag/knowledge/services/embedding_models.py
Normal file
@@ -0,0 +1,223 @@
|
||||
# services/embedding_models.py
|
||||
|
||||
import os
|
||||
from typing import Dict, Any, List, Type, Optional
|
||||
import logging
|
||||
import aiohttp # Import aiohttp for asynchronous requests
|
||||
import asyncio
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Base class for all embedding models
|
||||
class BaseEmbeddingModel:
|
||||
def __init__(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
self._embedding_dimension = None
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronously embeds a list of texts."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronously embeds a single query text."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def embedding_dimension(self) -> int:
|
||||
"""Returns the embedding dimension of the model."""
|
||||
if self._embedding_dimension is None:
|
||||
raise NotImplementedError("Embedding dimension not set for this model.")
|
||||
return self._embedding_dimension
|
||||
|
||||
class EmbeddingModelFactory:
|
||||
@staticmethod
|
||||
def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel:
|
||||
"""
|
||||
Factory method to create an embedding model instance.
|
||||
Currently only supports 'third_party_api' types.
|
||||
"""
|
||||
if model_name_key not in EMBEDDING_MODEL_CONFIGS:
|
||||
raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.")
|
||||
|
||||
config = EMBEDDING_MODEL_CONFIGS[model_name_key]
|
||||
|
||||
if config['type'] == "third_party_api":
|
||||
required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension']
|
||||
if not all(key in config for key in required_keys):
|
||||
raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}")
|
||||
|
||||
# Retrieve model_name from config if it differs from model_name_key
|
||||
# Some APIs expect a specific 'model' value in the payload that might be different from the key
|
||||
api_model_name = config.get('model_name', model_name_key)
|
||||
|
||||
return ThirdPartyAPIEmbeddingModel(
|
||||
model_name=api_model_name, # Use the model_name from config or the key
|
||||
api_endpoint=config['api_endpoint'],
|
||||
headers=config['headers'],
|
||||
payload_template=config['payload_template'],
|
||||
embedding_dimension=config['embedding_dimension']
|
||||
)
|
||||
|
||||
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
try:
|
||||
# SentenceTransformer is inherently synchronous, but we'll wrap its calls
|
||||
# in async methods. The actual computation will still block the event loop
|
||||
# if not run in a separate thread/process, but this keeps the API consistent.
|
||||
self.model = SentenceTransformer(model_name)
|
||||
self._embedding_dimension = self.model.get_sentence_embedding_dimension()
|
||||
logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
|
||||
raise
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
# For CPU-bound tasks like local model inference, consider running in a thread pool
|
||||
# to prevent blocking the event loop for long operations.
|
||||
# For simplicity here, we'll call it directly.
|
||||
return self.model.encode(texts).tolist()
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
return self.model.encode(text).tolist()
|
||||
|
||||
|
||||
class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
|
||||
def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int):
|
||||
super().__init__(model_name)
|
||||
self.api_endpoint = api_endpoint
|
||||
self.headers = headers
|
||||
self.payload_template = payload_template
|
||||
self._embedding_dimension = embedding_dimension
|
||||
self.session = None # aiohttp client session will be initialized on first use or in a context manager
|
||||
logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}")
|
||||
|
||||
async def _get_session(self):
|
||||
"""Lazily create or return the aiohttp client session."""
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self.session
|
||||
|
||||
async def close_session(self):
|
||||
"""Explicitly close the aiohttp client session."""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
logger.info(f"Closed aiohttp session for model {self.model_name}")
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronously embeds a list of texts using the third-party API."""
|
||||
session = await self._get_session()
|
||||
embeddings = []
|
||||
tasks = []
|
||||
for text in texts:
|
||||
payload = self.payload_template.copy()
|
||||
if 'input' in payload:
|
||||
payload['input'] = text
|
||||
elif 'texts' in payload:
|
||||
payload['texts'] = [text]
|
||||
else:
|
||||
raise ValueError("Payload template does not contain expected text input key.")
|
||||
|
||||
tasks.append(self._make_api_request(session, payload))
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"Error embedding text '{texts[i][:50]}...': {res}")
|
||||
# Depending on your error handling strategy, you might:
|
||||
# - Append None or an empty list
|
||||
# - Re-raise the exception to stop processing
|
||||
# - Log and skip, then continue
|
||||
embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure
|
||||
else:
|
||||
embeddings.append(res)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]:
|
||||
"""Helper to make an asynchronous API request and extract embedding."""
|
||||
try:
|
||||
async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response:
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
|
||||
api_response = await response.json()
|
||||
|
||||
# Adjust this based on your API's actual response structure
|
||||
if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]:
|
||||
embedding = api_response["data"][0]["embedding"]
|
||||
if len(embedding) != self.embedding_dimension:
|
||||
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
|
||||
return embedding
|
||||
elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]:
|
||||
embedding = api_response["embeddings"][0]
|
||||
if len(embedding) != self.embedding_dimension:
|
||||
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
|
||||
return embedding
|
||||
else:
|
||||
raise ValueError(f"Unexpected API response structure: {api_response}")
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"API request failed: {e}") from e
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error processing API response: {e}") from e
|
||||
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronously embeds a single query text."""
|
||||
results = await self.embed_documents([text])
|
||||
if results:
|
||||
return results[0]
|
||||
return [] # Or raise an error if embedding a query must always succeed
|
||||
|
||||
# --- Embedding Model Configuration ---
|
||||
EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
|
||||
"MiniLM": { # Example for a local Sentence Transformer model
|
||||
"type": "sentence_transformer",
|
||||
"model_name": "sentence-transformers/all-MiniLM-L6-v2"
|
||||
},
|
||||
"bge-m3": { # Example for a third-party API model
|
||||
"type": "third_party_api",
|
||||
"model_name": "bge-m3",
|
||||
"api_endpoint": "https://api.qhaigc.net/v1/embeddings",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.getenv('rag_api_key')}"
|
||||
},
|
||||
"payload_template": {
|
||||
"model": "bge-m3",
|
||||
"input": ""
|
||||
},
|
||||
"embedding_dimension": 1024
|
||||
},
|
||||
"OpenAI-Ada-002": {
|
||||
"type": "third_party_api",
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"api_endpoint": "https://api.openai.com/v1/embeddings",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set
|
||||
},
|
||||
"payload_template": {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "" # Text will be injected here
|
||||
},
|
||||
"embedding_dimension": 1536
|
||||
},
|
||||
"OpenAI-Embedding-3-Small": {
|
||||
"type": "third_party_api",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"api_endpoint": "https://api.openai.com/v1/embeddings",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
|
||||
},
|
||||
"payload_template": {
|
||||
"model": "text-embedding-3-small",
|
||||
"input": "",
|
||||
# "dimensions": 512 # Optional: uncomment if you want a specific output dimension
|
||||
},
|
||||
"embedding_dimension": 1536 # Default max dimension for text-embedding-3-small
|
||||
},
|
||||
}
|
||||
288
pkg/rag/knowledge/services/parser.py
Normal file
288
pkg/rag/knowledge/services/parser.py
Normal file
@@ -0,0 +1,288 @@
|
||||
|
||||
import PyPDF2
|
||||
from docx import Document
|
||||
import pandas as pd
|
||||
import csv
|
||||
import chardet
|
||||
from typing import Union, List, Callable, Any
|
||||
import logging
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
import ebooklib
|
||||
from ebooklib import epub
|
||||
import re
|
||||
import asyncio # Import asyncio for async operations
|
||||
import os
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class FileParser:
|
||||
"""
|
||||
A robust file parser class to extract text content from various document formats.
|
||||
It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files.
|
||||
All core file reading operations are designed to be run synchronously in a thread pool
|
||||
to avoid blocking the asyncio event loop.
|
||||
"""
|
||||
def __init__(self):
|
||||
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
"""
|
||||
Runs a synchronous function in a separate thread to prevent blocking the event loop.
|
||||
This is a general utility method for wrapping blocking I/O operations.
|
||||
"""
|
||||
try:
|
||||
return await asyncio.to_thread(sync_func, *args, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}")
|
||||
raise
|
||||
|
||||
async def parse(self, file_path: str) -> Union[str, None]:
|
||||
"""
|
||||
Parses the file based on its extension and returns the extracted text content.
|
||||
This is the main asynchronous entry point for parsing.
|
||||
|
||||
Args:
|
||||
file_path (str): The path to the file to be parsed.
|
||||
|
||||
Returns:
|
||||
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
|
||||
"""
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
self.logger.error(f"Invalid file path provided: {file_path}")
|
||||
return None
|
||||
|
||||
file_extension = file_path.split('.')[-1].lower()
|
||||
parser_method = getattr(self, f'_parse_{file_extension}', None)
|
||||
|
||||
if parser_method is None:
|
||||
self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Pass file_path to the specific parser methods
|
||||
return await parser_method(file_path)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}")
|
||||
return None
|
||||
|
||||
# --- Helper for reading files with encoding detection ---
|
||||
async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]:
|
||||
"""
|
||||
Reads a file with automatic encoding detection, ensuring the synchronous
|
||||
file read operation runs in a separate thread.
|
||||
"""
|
||||
def _read_sync():
|
||||
with open(file_path, 'rb') as file:
|
||||
raw_data = file.read()
|
||||
detected = chardet.detect(raw_data)
|
||||
encoding = detected['encoding'] or 'utf-8'
|
||||
|
||||
if mode == 'r':
|
||||
return raw_data.decode(encoding, errors='ignore')
|
||||
return raw_data # For binary mode
|
||||
|
||||
return await self._run_sync(_read_sync)
|
||||
|
||||
# --- Specific Parser Methods ---
|
||||
|
||||
async def _parse_txt(self, file_path: str) -> str:
|
||||
"""Parses a TXT file and returns its content."""
|
||||
self.logger.info(f"Parsing TXT file: {file_path}")
|
||||
return await self._read_file_content(file_path, mode='r')
|
||||
|
||||
async def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Parses a PDF file and returns its text content."""
|
||||
self.logger.info(f"Parsing PDF file: {file_path}")
|
||||
def _parse_pdf_sync():
|
||||
text_content = []
|
||||
with open(file_path, 'rb') as file:
|
||||
pdf_reader = PyPDF2.PdfReader(file)
|
||||
for page in pdf_reader.pages:
|
||||
text = page.extract_text()
|
||||
if text:
|
||||
text_content.append(text)
|
||||
return '\n'.join(text_content)
|
||||
return await self._run_sync(_parse_pdf_sync)
|
||||
|
||||
async def _parse_docx(self, file_path: str) -> str:
|
||||
"""Parses a DOCX file and returns its text content."""
|
||||
self.logger.info(f"Parsing DOCX file: {file_path}")
|
||||
def _parse_docx_sync():
|
||||
doc = Document(file_path)
|
||||
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
|
||||
return '\n'.join(text_content)
|
||||
return await self._run_sync(_parse_docx_sync)
|
||||
|
||||
async def _parse_doc(self, file_path: str) -> str:
|
||||
"""Handles .doc files, explicitly stating lack of direct support."""
|
||||
self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.")
|
||||
raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.")
|
||||
|
||||
async def _parse_xlsx(self, file_path: str) -> str:
|
||||
"""Parses an XLSX file, returning text from all sheets."""
|
||||
self.logger.info(f"Parsing XLSX file: {file_path}")
|
||||
def _parse_xlsx_sync():
|
||||
excel_file = pd.ExcelFile(file_path)
|
||||
all_sheet_content = []
|
||||
for sheet_name in excel_file.sheet_names:
|
||||
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
||||
sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n"
|
||||
all_sheet_content.append(sheet_text)
|
||||
return '\n'.join(all_sheet_content)
|
||||
return await self._run_sync(_parse_xlsx_sync)
|
||||
|
||||
async def _parse_csv(self, file_path: str) -> str:
|
||||
"""Parses a CSV file and returns its content as a string."""
|
||||
self.logger.info(f"Parsing CSV file: {file_path}")
|
||||
def _parse_csv_sync():
|
||||
# pd.read_csv can often detect encoding, but explicit detection is safer
|
||||
raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function
|
||||
# For simplicity, we'll let pandas handle encoding internally after a raw read.
|
||||
# A more robust solution might pass encoding directly to pd.read_csv after detection.
|
||||
detected = chardet.detect(open(file_path, 'rb').read())
|
||||
encoding = detected['encoding'] or 'utf-8'
|
||||
df = pd.read_csv(file_path, encoding=encoding)
|
||||
return df.to_string(index=False)
|
||||
return await self._run_sync(_parse_csv_sync)
|
||||
|
||||
async def _parse_markdown(self, file_path: str) -> str:
|
||||
"""Parses a Markdown file, converting it to structured plain text."""
|
||||
self.logger.info(f"Parsing Markdown file: {file_path}")
|
||||
def _parse_markdown_sync():
|
||||
md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function
|
||||
html_content = markdown.markdown(
|
||||
md_content,
|
||||
extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
|
||||
)
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
text_parts = []
|
||||
for element in soup.children:
|
||||
if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
|
||||
level = int(element.name[1])
|
||||
text_parts.append('#' * level + ' ' + element.get_text().strip())
|
||||
elif element.name == 'p':
|
||||
text = element.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif element.name in ['ul', 'ol']:
|
||||
for li in element.find_all('li'):
|
||||
text_parts.append(f"* {li.get_text().strip()}")
|
||||
elif element.name == 'pre':
|
||||
code_block = element.get_text().strip()
|
||||
if code_block:
|
||||
text_parts.append(f"```\n{code_block}\n```")
|
||||
elif element.name == 'table':
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
if table_str:
|
||||
text_parts.append(table_str)
|
||||
elif element.name:
|
||||
text = element.get_text(separator=' ', strip=True)
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
|
||||
return cleaned_text.strip()
|
||||
return await self._run_sync(_parse_markdown_sync)
|
||||
|
||||
async def _parse_html(self, file_path: str) -> str:
|
||||
"""Parses an HTML file, extracting structured plain text."""
|
||||
self.logger.info(f"Parsing HTML file: {file_path}")
|
||||
def _parse_html_sync():
|
||||
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
for script_or_style in soup(["script", "style"]):
|
||||
script_or_style.decompose()
|
||||
text_parts = []
|
||||
for element in soup.body.children if soup.body else soup.children:
|
||||
if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
|
||||
level = int(element.name[1])
|
||||
text_parts.append('#' * level + ' ' + element.get_text().strip())
|
||||
elif element.name == 'p':
|
||||
text = element.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
elif element.name in ['ul', 'ol']:
|
||||
for li in element.find_all('li'):
|
||||
text = li.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(f"* {text}")
|
||||
elif element.name == 'table':
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
if table_str:
|
||||
text_parts.append(table_str)
|
||||
elif element.name:
|
||||
text = element.get_text(separator=' ', strip=True)
|
||||
if text:
|
||||
text_parts.append(text)
|
||||
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
|
||||
return cleaned_text.strip()
|
||||
return await self._run_sync(_parse_html_sync)
|
||||
|
||||
async def _parse_epub(self, file_path: str) -> str:
|
||||
"""Parses an EPUB file, extracting metadata and content."""
|
||||
self.logger.info(f"Parsing EPUB file: {file_path}")
|
||||
def _parse_epub_sync():
|
||||
book = epub.read_epub(file_path)
|
||||
text_content = []
|
||||
title_meta = book.get_metadata('DC', 'title')
|
||||
if title_meta:
|
||||
text_content.append(f"Title: {title_meta[0][0]}")
|
||||
creator_meta = book.get_metadata('DC', 'creator')
|
||||
if creator_meta:
|
||||
text_content.append(f"Author: {creator_meta[0][0]}")
|
||||
date_meta = book.get_metadata('DC', 'date')
|
||||
if date_meta:
|
||||
text_content.append(f"Publish Date: {date_meta[0][0]}")
|
||||
toc = book.get_toc()
|
||||
if toc:
|
||||
text_content.append("\n--- Table of Contents ---")
|
||||
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
|
||||
text_content.append("--- End of Table of Contents ---\n")
|
||||
for item in book.get_items():
|
||||
if item.get_type() == ebooklib.ITEM_DOCUMENT:
|
||||
html_content = item.get_content().decode('utf-8', errors='ignore')
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
for junk in soup(["script", "style", "nav", "header", "footer"]):
|
||||
junk.decompose()
|
||||
text = soup.get_text(separator='\n', strip=True)
|
||||
text = re.sub(r'\n\s*\n', '\n\n', text)
|
||||
if text:
|
||||
text_content.append(text)
|
||||
return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip()
|
||||
return await self._run_sync(_parse_epub_sync)
|
||||
|
||||
def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
|
||||
"""Recursively adds TOC items to text_content (synchronous helper)."""
|
||||
indent = ' ' * level
|
||||
for item in toc_list:
|
||||
if isinstance(item, tuple):
|
||||
chapter, subchapters = item
|
||||
text_content.append(f"{indent}- {chapter.title}")
|
||||
self._add_toc_items_sync(subchapters, text_content, level + 1)
|
||||
else:
|
||||
text_content.append(f"{indent}- {item.title}")
|
||||
|
||||
def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
|
||||
"""Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
|
||||
headers = [th.get_text().strip() for th in table_element.find_all('th')]
|
||||
rows = []
|
||||
for tr in table_element.find_all('tr'):
|
||||
cells = [td.get_text().strip() for td in tr.find_all('td')]
|
||||
if cells:
|
||||
rows.append(cells)
|
||||
|
||||
if not headers and not rows:
|
||||
return ""
|
||||
|
||||
table_lines = []
|
||||
if headers:
|
||||
table_lines.append(' | '.join(headers))
|
||||
table_lines.append(' | '.join(['---'] * len(headers)))
|
||||
|
||||
for row_cells in rows:
|
||||
padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
|
||||
table_lines.append(' | '.join(padded_cells))
|
||||
|
||||
return '\n'.join(table_lines)
|
||||
106
pkg/rag/knowledge/services/retriever.py
Normal file
106
pkg/rag/knowledge/services/retriever.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# services/retriever.py
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np # Make sure numpy is imported
|
||||
from typing import List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
|
||||
from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
|
||||
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class Retriever(BaseService):
|
||||
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
|
||||
super().__init__()
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
self.model_type = model_type
|
||||
self.model_name_key = model_name_key
|
||||
self.chroma_manager = chroma_manager
|
||||
|
||||
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
|
||||
|
||||
def _load_embedding_model(self) -> BaseEmbeddingModel:
|
||||
self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...")
|
||||
try:
|
||||
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
|
||||
self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
|
||||
return model
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
|
||||
raise
|
||||
|
||||
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
||||
if not self.embedding_model:
|
||||
raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.")
|
||||
|
||||
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
|
||||
|
||||
query_embedding: List[float] = await self.embedding_model.embed_query(query)
|
||||
query_embedding_np = np.array([query_embedding], dtype=np.float32)
|
||||
|
||||
chroma_results = await self._run_sync(
|
||||
self.chroma_manager.search_sync,
|
||||
query_embedding_np, k
|
||||
)
|
||||
|
||||
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
|
||||
matched_chroma_ids = chroma_results.get("ids", [[]])[0]
|
||||
distances = chroma_results.get("distances", [[]])[0]
|
||||
chroma_metadatas = chroma_results.get("metadatas", [[]])[0]
|
||||
chroma_documents = chroma_results.get("documents", [[]])[0]
|
||||
|
||||
if not matched_chroma_ids:
|
||||
self.logger.info("No relevant chunks found in Chroma.")
|
||||
return []
|
||||
|
||||
db_chunk_ids = []
|
||||
for metadata in chroma_metadatas:
|
||||
if "chunk_id" in metadata:
|
||||
db_chunk_ids.append(metadata["chunk_id"])
|
||||
else:
|
||||
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
|
||||
|
||||
if not db_chunk_ids:
|
||||
self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.")
|
||||
return []
|
||||
|
||||
self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...")
|
||||
chunks_from_db = await self._run_sync(
|
||||
lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
|
||||
db_chunk_ids
|
||||
)
|
||||
|
||||
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
|
||||
results_list: List[Dict[str, Any]] = []
|
||||
|
||||
for i, chroma_id in enumerate(matched_chroma_ids):
|
||||
try:
|
||||
# Ensure original_chunk_id is int for DB lookup
|
||||
original_chunk_id = int(chroma_id.split('_')[-1])
|
||||
except (ValueError, IndexError):
|
||||
self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.")
|
||||
continue
|
||||
|
||||
chunk_text_from_chroma = chroma_documents[i]
|
||||
distance = float(distances[i])
|
||||
file_id_from_chroma = chroma_metadatas[i].get("file_id")
|
||||
|
||||
chunk_from_db = chunk_map.get(original_chunk_id)
|
||||
|
||||
results_list.append({
|
||||
"chunk_id": original_chunk_id,
|
||||
"text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
|
||||
"distance": distance,
|
||||
"file_id": file_id_from_chroma
|
||||
})
|
||||
|
||||
self.logger.info(f"Retrieved {len(results_list)} chunks.")
|
||||
return results_list
|
||||
|
||||
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
|
||||
self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).")
|
||||
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
|
||||
session.close()
|
||||
return chunks
|
||||
0
pkg/rag/knowledge/utils/crawler.py
Normal file
0
pkg/rag/knowledge/utils/crawler.py
Normal file
@@ -49,6 +49,17 @@ dependencies = [
|
||||
"ruff>=0.11.9",
|
||||
"pre-commit>=4.2.0",
|
||||
"uv>=0.7.11",
|
||||
"PyPDF2>=3.0.1",
|
||||
"python-docx>=1.1.0",
|
||||
"pandas>=2.2.2",
|
||||
"chardet>=5.2.0",
|
||||
"markdown>=3.6",
|
||||
"beautifulsoup4>=4.12.3",
|
||||
"ebooklib>=0.18",
|
||||
"html2text>=2024.2.26",
|
||||
"langchain>=0.2.0",
|
||||
"chromadb>=0.4.24",
|
||||
"sentence-transformers>=2.6.1",
|
||||
]
|
||||
keywords = [
|
||||
"bot",
|
||||
|
||||
Reference in New Issue
Block a user