This commit is contained in:
WangCham
2025-06-25 14:32:53 +08:00
committed by Junyan Qin
parent 348f6d9eaa
commit 4bcc06c955
14 changed files with 1305 additions and 0 deletions

View File

@@ -0,0 +1,83 @@
from __future__ import annotations

import quart

from .. import group
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
# helper for building a success response
def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response:
return quart.jsonify({
"code": code,
"data": data or {},
"msg": msg
})
async def initialize(self) -> None:
rag = self.ap.knowledge_base_service  # RAG_Manager instance wired on the Application
@self.route('', methods=['POST', 'GET'])
async def _() -> str:
if quart.request.method == 'GET':
knowledge_bases = await rag.get_all_knowledge_bases()
bases_list = [
{
"uuid": kb.id,
"name": kb.name,
"description": kb.description,
} for kb in knowledge_bases
]
return self.success(code=0,
data={'bases': bases_list},
msg='ok')
json_data = await quart.request.json
knowledge_base_uuid = await rag.create_knowledge_base(
json_data.get('name'),
json_data.get('description')
)
return self.success(data={'uuid': knowledge_base_uuid})
@self.route('/<knowledge_base_uuid>', methods=['GET'])
async def _(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid)
if knowledge_base is None:
return self.http_status(404, -1, 'knowledge base not found')
return self.success(
code=0,
data={
"name": knowledge_base.name,
"description": knowledge_base.description,
"uuid": knowledge_base.id
},
msg='ok'
)
@self.route('/<knowledge_base_uuid>/files', methods=['GET'])
async def _(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
files = await rag.get_files_by_knowledge_base(knowledge_base_uuid)
return self.success(code=0, data=[{
"id": file.id,
"file_name": file.file_name,
"status": file.status
} for file in files], msg='ok')
# delete specific file in knowledge base
@self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
async def _(knowledge_base_uuid: str, file_id: str) -> str:
await rag.delete_data_by_file_id(file_id)
return self.success(code=0, msg='ok')
# delete specific kb
@self.route('/<knowledge_base_uuid>', methods=['DELETE'])
async def _(knowledge_base_uuid: str) -> str:
await rag.delete_kb_by_id(knowledge_base_uuid)
return self.success(code=0, msg='ok')
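
Assuming the app mounts this group under `/api/v1/knowledge/bases`, a client interaction might look like the sketch below; the host, port, and lack of auth headers are placeholders, not part of this commit.

```python
# Hypothetical client sketch for the routes above; host/port are placeholders.
import asyncio
import aiohttp

BASE = "http://127.0.0.1:5300/api/v1/knowledge/bases"  # assumed address

async def demo():
    async with aiohttp.ClientSession() as session:
        # POST creates a knowledge base
        async with session.post(BASE, json={"name": "docs", "description": "product docs"}) as resp:
            print(await resp.json())
        # GET lists all knowledge bases
        async with session.get(BASE) as resp:
            print((await resp.json())["data"]["bases"])

asyncio.run(demo())
```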

View File

@@ -27,6 +27,10 @@ from ..storage import mgr as storagemgr
from ..utils import logcache
from . import taskmgr
from . import entities as core_entities
from ...pkg.rag.knowledge import RAG_Manager
class Application:
@@ -99,6 +103,8 @@ class Application:
storage_mgr: storagemgr.StorageMgr = None
knowledge_base_service: RAG_Manager = None
# ========= HTTP Services =========
user_service: user_service.UserService = None
@@ -111,6 +117,7 @@ class Application:
bot_service: bot_service.BotService = None
def __init__(self):
pass

View File

@@ -0,0 +1,283 @@
# RAG_Manager: orchestrates file parsing, chunking, embedding, and retrieval.
import logging
import os
import asyncio
from services.parser import FileParser
from services.chunker import Chunker
from services.embedder import Embedder
from services.retriever import Retriever
from services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk # Ensure Chunk is imported if you need to manipulate it directly
from services.embedding_models import EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager
from ...core import app
class RAG_Manager:
def __init__(self, logger: logging.Logger = None):
self.logger = logger or logging.getLogger(__name__)
self.embedding_model_type = None
self.embedding_model_name = None
self.chroma_manager = None
self.parser = None
self.chunker = None
self.embedder = None
self.retriever = None
self.ap = app.Application
async def initialize_system(self):
await asyncio.to_thread(create_db_and_tables)
async def create_model(self, embedding_model_type: str,
embedding_model_name: str):
self.embedding_model_type = embedding_model_type
self.embedding_model_name = embedding_model_name
try:
model = EmbeddingModelFactory.create_model(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name
)
self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}")
except Exception as e:
self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}")
raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.")
self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}")
self.parser = FileParser()
self.chunker = Chunker()
# Pass chroma_manager to Embedder and Retriever
self.embedder = Embedder(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager # Inject dependency
)
self.retriever = Retriever(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager # Inject dependency
)
async def create_knowledge_base(self, kb_name: str, kb_description: str):
"""
Creates a new knowledge base with the given name and description.
If a knowledge base with the same name already exists, it returns that one.
"""
try:
def _get_kb_sync(name):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(name=name).first()
finally:
session.close()
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
if not kb:
def _add_kb_sync():
session = SessionLocal()
try:
new_kb = KnowledgeBase(name=kb_name, description=kb_description)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
return new_kb
finally:
session.close()
kb = await asyncio.to_thread(_add_kb_sync)
return kb.id
except Exception as e:
self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
raise
async def get_all_knowledge_bases(self):
"""
Retrieves all knowledge bases from the database.
"""
try:
def _get_all_kbs_sync():
session = SessionLocal()
try:
return session.query(KnowledgeBase).all()
finally:
session.close()
kbs = await asyncio.to_thread(_get_all_kbs_sync)
return kbs
except Exception as e:
self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True)
return []
async def get_knowledge_base_by_id(self, kb_id: int):
"""
Retrieves a knowledge base by its ID.
"""
try:
def _get_kb_sync(kb_id):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(id=kb_id).first()
finally:
session.close()
kb = await asyncio.to_thread(_get_kb_sync, kb_id)
return kb
except Exception as e:
self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
return None
async def get_files_by_knowledge_base(self, kb_id: int):
try:
def _get_files_sync(kb_id):
session = SessionLocal()
try:
return session.query(File).filter_by(kb_id=kb_id).all()
finally:
session.close()
files = await asyncio.to_thread(_get_files_sync, kb_id)
return files
except Exception as e:
self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True)
return []
async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"):
self.logger.info(f"Starting data storage process for file: {file_path}")
try:
def _get_kb_sync(name):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(name=name).first()
finally:
session.close()
kb = await asyncio.to_thread(_get_kb_sync, kb_name)
if not kb:
self.logger.info(f"Knowledge Base '{kb_name}' not found. Creating a new one.")
def _add_kb_sync():
session = SessionLocal()
try:
new_kb = KnowledgeBase(name=kb_name, description=kb_description)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
return new_kb
finally:
session.close()
kb = await asyncio.to_thread(_add_kb_sync)
self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})")
def _add_file_sync(kb_id, file_name, path, file_type):
session = SessionLocal()
try:
file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type)
session.add(file)
session.commit()
session.refresh(file)
return file
finally:
session.close()
file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type)
self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})")
text = await self.parser.parse(file_path)
if not text:
self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.")
# Remove the now-useless file entry from the DB.
session = SessionLocal()
try:
session.delete(file_obj)
session.commit()
except Exception as del_e:
self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}")
finally:
session.close()
return
chunks_texts = await self.chunker.chunk(text)
self.logger.info(f"Chunked into {len(chunks_texts)} pieces.")
# embed_and_store now handles both DB chunk saving and Chroma embedding
await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
self.logger.info(f"Data storage process completed for file: {file_path}")
except Exception as e:
self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True)
# Consider cleaning up partially stored data if an error occurs.
return
async def retrieve_data(self, query: str):
self.logger.info(f"Starting data retrieval process for query: '{query}'")
try:
retrieved_chunks = await self.retriever.retrieve(query)
self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.")
return retrieved_chunks
except Exception as e:
self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
return []
async def delete_data_by_file_id(self, file_id: int):
"""
Deletes data associated with a specific file_id from both the relational DB and Chroma.
"""
self.logger.info(f"Starting data deletion process for file_id: {file_id}")
session = SessionLocal()
try:
# 1. Delete from Chroma
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
# 2. Delete chunks from relational DB
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
for chunk in chunks_to_delete:
session.delete(chunk)
self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.")
# 3. Delete file entry from relational DB
file_to_delete = session.query(File).filter_by(id=file_id).first()
if file_to_delete:
session.delete(file_to_delete)
self.logger.info(f"Deleted file entry {file_id} from relational DB.")
else:
self.logger.warning(f"File entry {file_id} not found in relational DB.")
session.commit()
self.logger.info(f"Data deletion completed for file_id: {file_id}.")
except Exception as e:
session.rollback()
self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True)
finally:
session.close()
async def delete_kb_by_id(self, kb_id: int):
"""
Deletes a knowledge base and all associated files and chunks.
"""
self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}")
session = SessionLocal()
try:
# 1. Get the knowledge base
kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if not kb:
self.logger.warning(f"Knowledge Base with ID {kb_id} not found.")
return
# 2. Delete all files associated with this knowledge base
files_to_delete = session.query(File).filter_by(kb_id=kb.id).all()
for file in files_to_delete:
await self.delete_data_by_file_id(file.id)
# 3. Delete the knowledge base itself
session.delete(kb)
session.commit()
self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
except Exception as e:
session.rollback()
self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
finally:
session.close()
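
A minimal driver for the manager, assuming the `bge-m3` entry from `EMBEDDING_MODEL_CONFIGS` and a local `sample.pdf`; not part of the commit itself.

```python
# Minimal usage sketch; "bge-m3" comes from EMBEDDING_MODEL_CONFIGS, sample.pdf is a placeholder.
import asyncio

async def main():
    rag = RAG_Manager()
    await rag.initialize_system()                         # create tables
    await rag.create_model("third_party_api", "bge-m3")   # wires parser, chunker, embedder, retriever
    await rag.store_data("./sample.pdf", kb_name="docs", file_type="pdf")
    for hit in await rag.retrieve_data("how do I reset my password"):
        print(hit["chunk_id"], round(hit["distance"], 3), hit["text"][:80])

asyncio.run(main())
```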

View File

View File

@@ -0,0 +1,26 @@
# Wraps synchronous operations so they can be awaited
import asyncio
import logging
from services.database import SessionLocal  # session factory
class BaseService:
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.db_session_factory = SessionLocal  # SessionLocal session factory
async def _run_sync(self, func, *args, **kwargs):
"""
在单独的线程中运行同步函数。
如果第一个参数是 session则在 to_thread 中获取新的 session。
"""
# 如果函数需要数据库会话作为第一个参数,我们在这里获取它
if getattr(func, '__name__', '').startswith('_db_'): # 约定:数据库操作的同步方法以 _db_ 开头
session = await asyncio.to_thread(self.db_session_factory)
try:
result = await asyncio.to_thread(func, session, *args, **kwargs)
return result
finally:
session.close()
else:
# Otherwise, run the synchronous function directly
return await asyncio.to_thread(func, *args, **kwargs)
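
To illustrate the `_db_` convention, a hypothetical subclass might look like this: `_run_sync` sees the prefix, opens a session in a worker thread, passes it as the first argument, and closes it afterwards.

```python
# Hypothetical subclass showing the _db_ naming convention handled by _run_sync.
from services.database import KnowledgeBase

class KBStatsService(BaseService):
    def _db_count_kbs(self, session):
        # session is injected by _run_sync because the method name starts with _db_
        return session.query(KnowledgeBase).count()

    async def count_kbs(self) -> int:
        return await self._run_sync(self._db_count_kbs)
```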

View File

@@ -0,0 +1,65 @@
# services/chroma_manager.py
import numpy as np
import logging
from chromadb import PersistentClient
import os
logger = logging.getLogger(__name__)
class ChromaIndexManager:
def __init__(self, collection_name: str = "default_collection"):
self.logger = logging.getLogger(self.__class__.__name__)
chroma_data_path = "./chroma_data"
os.makedirs(chroma_data_path, exist_ok=True)
self.client = PersistentClient(path=chroma_data_path)
self._collection_name = collection_name
self._collection = None
self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}")
@property
def collection(self):
if self._collection is None:
self._collection = self.client.get_or_create_collection(name=self._collection_name)
self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
return self._collection
def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]):
if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents):
raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.")
chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)]
metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)]
self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
self.collection.add(
embeddings=embeddings.tolist(),
ids=chroma_ids,
metadatas=metadatas,
documents=documents
)
self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
def search_sync(self, query_embedding: np.ndarray, k: int = 5):
"""
Searches the Chroma collection for the top-k nearest neighbors.
Args:
query_embedding: A numpy array of the query embedding.
k: The number of results to return.
Returns:
A dictionary containing query results from Chroma.
"""
self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.")
results = self.collection.query(
query_embeddings=query_embedding.tolist(),
n_results=k,
# 'ids' is returned by default and must not be listed in include
include=["metadatas", "distances", "documents"]
)
self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.")
return results
def delete_by_file_id_sync(self, file_id: int):
self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.")
self.collection.delete(where={"file_id": file_id})
self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.")

View File

@@ -0,0 +1,63 @@
# services/chunker.py
import logging
from typing import List
from services.base_service import BaseService # Assuming BaseService provides _run_sync
logger = logging.getLogger(__name__)
class Chunker(BaseService):
"""
A class for splitting long texts into smaller, overlapping chunks.
"""
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__() # Initialize BaseService
self.logger = logging.getLogger(self.__class__.__name__)
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
if self.chunk_overlap >= self.chunk_size:
self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.")
def _split_text_sync(self, text: str) -> List[str]:
"""
Synchronously splits a long text into chunks with specified overlap.
This is a CPU-bound operation, intended to be run in a separate thread.
"""
if not text:
return []
# Simple whitespace-based splitting for demonstration
# For more advanced chunking, consider libraries like LangChain's text splitters
words = text.split()
chunks = []
current_chunk = []
for word in words:
current_chunk.append(word)
if len(current_chunk) > self.chunk_size:
chunks.append(" ".join(current_chunk[:self.chunk_size]))
current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
if current_chunk:
chunks.append(" ".join(current_chunk))
# A more robust chunking strategy (e.g., using recursive character text splitter)
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# text_splitter = RecursiveCharacterTextSplitter(
# chunk_size=self.chunk_size,
# chunk_overlap=self.chunk_overlap,
# length_function=len,
# is_separator_regex=False,
# )
# return text_splitter.split_text(text)
return [chunk for chunk in chunks if chunk.strip()] # Filter out empty chunks
async def chunk(self, text: str) -> List[str]:
"""
Asynchronously chunks a given text into smaller pieces.
"""
self.logger.info(f"Chunking text (length: {len(text)})...")
# Run the synchronous splitting logic in a separate thread
chunks = await self._run_sync(self._split_text_sync, text)
self.logger.info(f"Text chunked into {len(chunks)} pieces.")
return chunks
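
With the word-based splitter above, `chunk_size` and `chunk_overlap` are counted in words, so consecutive chunks share `chunk_overlap` words, as in this small sketch.

```python
# Tiny demo of the overlap behaviour; sizes are deliberately small.
import asyncio

async def main():
    chunker = Chunker(chunk_size=5, chunk_overlap=2)
    chunks = await chunker.chunk("one two three four five six seven eight nine ten")
    print(chunks)
    # -> ['one two three four five', 'four five six seven eight', 'seven eight nine ten']

asyncio.run(main())
```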

View File

@@ -0,0 +1,57 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime
import numpy as np  # used when decoding embeddings stored as LargeBinary
Base = declarative_base()
class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
files = relationship("File", back_populates="knowledge_base")
class File(Base):
__tablename__ = 'file'
id = Column(Integer, primary_key=True, index=True)
kb_id = Column(Integer, ForeignKey('kb.id'))
file_name = Column(String)
path = Column(String)
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0) # 0: pending, 1: processing, 2: processed, 3: error
knowledge_base = relationship("KnowledgeBase", back_populates="files")
chunks = relationship("Chunk", back_populates="file")
class Chunk(Base):
__tablename__ = 'chunks'
id = Column(Integer, primary_key=True, index=True)
file_id = Column(Integer, ForeignKey('file.id'))
text = Column(Text)
file = relationship("File", back_populates="chunks")
vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one
class Vector(Base):
__tablename__ = 'vectors'
id = Column(Integer, primary_key=True, index=True)
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
embedding = Column(LargeBinary) # Store embeddings as binary
chunk = relationship("Chunk", back_populates="vector")
# Database connection
DATABASE_URL = "sqlite:///./knowledge_base.db" # switch to PostgreSQL/MySQL in production
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Create all tables (run once at application startup)
def create_db_and_tables():
Base.metadata.create_all(bind=engine)
print("Database tables created/checked.")
# Embedding dimension (adjust to match the model actually in use)
EMBEDDING_DIM = 1024
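
A quick way to inspect what has been stored, using the session factory and models defined above.

```python
# Inspection sketch using the models above; run after some files have been ingested.
create_db_and_tables()
session = SessionLocal()
try:
    for f in session.query(File).all():
        print(f.file_name, f.file_type, f.status, "chunks:", len(f.chunks))
finally:
    session.close()
```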

View File

@@ -0,0 +1,93 @@
# services/embedder.py
import asyncio
import logging
import numpy as np
from typing import List
from sqlalchemy.orm import Session
from services.base_service import BaseService
from services.database import Chunk, SessionLocal
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager # Import the manager
logger = logging.getLogger(__name__)
class Embedder(BaseService):
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.model_type = model_type
self.model_name_key = model_name_key
self.chroma_manager = chroma_manager # Dependency Injection
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
def _load_embedding_model(self) -> BaseEmbeddingModel:
self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...")
try:
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
return model
except Exception as e:
self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}")
raise
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
"""
Saves chunks to the relational database and returns the created Chunk objects.
This function assumes it's called within a context where the session
will be committed/rolled back and closed by the caller.
"""
self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).")
chunk_objects = []
for text in chunks_texts:
chunk = Chunk(file_id=file_id, text=text)
session.add(chunk)
chunk_objects.append(chunk)
session.flush() # This populates the .id attribute for each new chunk object
self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.")
return chunk_objects
async def embed_and_store(self, file_id: int, chunks: List[str]):
if not self.embedding_model:
raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.")
self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...")
session = SessionLocal() # Start a session that will live for the whole operation
chunk_objects = []
try:
# 1. Save chunks to the relational database first to get their IDs
# We call _db_save_chunks_sync directly without _run_sync's session management
# because we manage the session here across multiple async calls.
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
session.commit() # Commit chunks to make their IDs permanent and accessible
if not chunk_objects:
self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.")
return []
# 2. Generate embeddings
embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks)
embeddings_np = np.array(embeddings, dtype=np.float32)
if embeddings_np.shape[1] != self.embedding_model.embedding_dimension:
self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.")
raise ValueError("Embedding dimension mismatch during embedding process.")
self.logger.info("Saving embeddings to Chroma...")
chunk_ids = [c.id for c in chunk_objects] # Now safe to access .id because session is still open and committed
file_ids_for_chroma = [file_id] * len(chunk_ids)
await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call
self.chroma_manager.add_embeddings_sync,
file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents
)
self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.")
return chunk_objects
except Exception as e:
session.rollback() # Rollback on any error
self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True)
raise # Re-raise the exception to propagate it
finally:
session.close() # Ensure the session is always closed

View File

@@ -0,0 +1,223 @@
# services/embedding_models.py
import os
from typing import Dict, Any, List, Type, Optional
import logging
import aiohttp # Import aiohttp for asynchronous requests
import asyncio
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
# Base class for all embedding models
class BaseEmbeddingModel:
def __init__(self, model_name: str):
self.model_name = model_name
self._embedding_dimension = None
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously embeds a list of texts."""
raise NotImplementedError
async def embed_query(self, text: str) -> List[float]:
"""Asynchronously embeds a single query text."""
raise NotImplementedError
@property
def embedding_dimension(self) -> int:
"""Returns the embedding dimension of the model."""
if self._embedding_dimension is None:
raise NotImplementedError("Embedding dimension not set for this model.")
return self._embedding_dimension
class EmbeddingModelFactory:
@staticmethod
def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel:
"""
Factory method to create an embedding model instance.
Currently only supports 'third_party_api' types.
"""
if model_name_key not in EMBEDDING_MODEL_CONFIGS:
raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.")
config = EMBEDDING_MODEL_CONFIGS[model_name_key]
if config['type'] == "third_party_api":
required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension']
if not all(key in config for key in required_keys):
raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}")
# Retrieve model_name from config if it differs from model_name_key
# Some APIs expect a specific 'model' value in the payload that might be different from the key
api_model_name = config.get('model_name', model_name_key)
return ThirdPartyAPIEmbeddingModel(
model_name=api_model_name, # Use the model_name from config or the key
api_endpoint=config['api_endpoint'],
headers=config['headers'],
payload_template=config['payload_template'],
embedding_dimension=config['embedding_dimension']
)
elif config['type'] == "sentence_transformer":
return SentenceTransformerEmbeddingModel(model_name=config.get('model_name', model_name_key))
else:
raise ValueError(f"Unsupported embedding model type '{config['type']}' for '{model_name_key}'.")
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_name: str):
super().__init__(model_name)
try:
# SentenceTransformer is inherently synchronous, but we'll wrap its calls
# in async methods. The actual computation will still block the event loop
# if not run in a separate thread/process, but this keeps the API consistent.
self.model = SentenceTransformer(model_name)
self._embedding_dimension = self.model.get_sentence_embedding_dimension()
logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}")
except Exception as e:
logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
raise
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
# encode() is CPU-bound; run it in a worker thread so it does not block the event loop
embeddings = await asyncio.to_thread(self.model.encode, texts)
return embeddings.tolist()
async def embed_query(self, text: str) -> List[float]:
embedding = await asyncio.to_thread(self.model.encode, text)
return embedding.tolist()
class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int):
super().__init__(model_name)
self.api_endpoint = api_endpoint
self.headers = headers
self.payload_template = payload_template
self._embedding_dimension = embedding_dimension
self.session = None # aiohttp client session will be initialized on first use or in a context manager
logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}")
async def _get_session(self):
"""Lazily create or return the aiohttp client session."""
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession()
return self.session
async def close_session(self):
"""Explicitly close the aiohttp client session."""
if self.session and not self.session.closed:
await self.session.close()
self.session = None
logger.info(f"Closed aiohttp session for model {self.model_name}")
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously embeds a list of texts using the third-party API."""
session = await self._get_session()
embeddings = []
tasks = []
for text in texts:
payload = self.payload_template.copy()
if 'input' in payload:
payload['input'] = text
elif 'texts' in payload:
payload['texts'] = [text]
else:
raise ValueError("Payload template does not contain expected text input key.")
tasks.append(self._make_api_request(session, payload))
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, res in enumerate(results):
if isinstance(res, Exception):
logger.error(f"Error embedding text '{texts[i][:50]}...': {res}")
# Keep list positions aligned by appending a zero vector on failure;
# callers may prefer to re-raise or skip instead.
embeddings.append([0.0] * self.embedding_dimension)
else:
embeddings.append(res)
return embeddings
async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]:
"""Helper to make an asynchronous API request and extract embedding."""
try:
async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response:
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
api_response = await response.json()
# Adjust this based on your API's actual response structure
if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]:
embedding = api_response["data"][0]["embedding"]
if len(embedding) != self.embedding_dimension:
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
return embedding
elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]:
embedding = api_response["embeddings"][0]
if len(embedding) != self.embedding_dimension:
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
return embedding
else:
raise ValueError(f"Unexpected API response structure: {api_response}")
except aiohttp.ClientError as e:
raise ConnectionError(f"API request failed: {e}") from e
except ValueError as e:
raise ValueError(f"Error processing API response: {e}") from e
async def embed_query(self, text: str) -> List[float]:
"""Asynchronously embeds a single query text."""
results = await self.embed_documents([text])
if results:
return results[0]
return [] # Or raise an error if embedding a query must always succeed
# --- Embedding Model Configuration ---
EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
"MiniLM": { # Example for a local Sentence Transformer model
"type": "sentence_transformer",
"model_name": "sentence-transformers/all-MiniLM-L6-v2"
},
"bge-m3": { # Example for a third-party API model
"type": "third_party_api",
"model_name": "bge-m3",
"api_endpoint": "https://api.qhaigc.net/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('rag_api_key')}"
},
"payload_template": {
"model": "bge-m3",
"input": ""
},
"embedding_dimension": 1024
},
"OpenAI-Ada-002": {
"type": "third_party_api",
"model_name": "text-embedding-ada-002",
"api_endpoint": "https://api.openai.com/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set
},
"payload_template": {
"model": "text-embedding-ada-002",
"input": "" # Text will be injected here
},
"embedding_dimension": 1536
},
"OpenAI-Embedding-3-Small": {
"type": "third_party_api",
"model_name": "text-embedding-3-small",
"api_endpoint": "https://api.openai.com/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
},
"payload_template": {
"model": "text-embedding-3-small",
"input": "",
# "dimensions": 512 # Optional: uncomment if you want a specific output dimension
},
"embedding_dimension": 1536 # Default max dimension for text-embedding-3-small
},
}
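
Any OpenAI-compatible `/v1/embeddings` endpoint can be registered the same way; the entry below is purely illustrative, with a placeholder endpoint, environment variable, and dimension.

```python
# Illustrative extra entry; endpoint, env var, and dimension are placeholders.
EMBEDDING_MODEL_CONFIGS["my-embedding"] = {
    "type": "third_party_api",
    "model_name": "my-embedding-model",
    "api_endpoint": "https://example.com/v1/embeddings",
    "headers": {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {os.getenv('MY_EMBEDDING_API_KEY')}",
    },
    "payload_template": {"model": "my-embedding-model", "input": ""},
    "embedding_dimension": 768,
}
```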

View File

@@ -0,0 +1,288 @@
import PyPDF2
from docx import Document
import pandas as pd
import csv
import chardet
from typing import Union, List, Callable, Any
import logging
import markdown
from bs4 import BeautifulSoup
import ebooklib
from ebooklib import epub
import re
import asyncio # Import asyncio for async operations
import os
# Configure logging
logger = logging.getLogger(__name__)
class FileParser:
"""
A robust file parser class to extract text content from various document formats.
It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files.
All core file reading operations are designed to be run synchronously in a thread pool
to avoid blocking the asyncio event loop.
"""
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
"""
Runs a synchronous function in a separate thread to prevent blocking the event loop.
This is a general utility method for wrapping blocking I/O operations.
"""
try:
return await asyncio.to_thread(sync_func, *args, **kwargs)
except Exception as e:
self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}")
raise
async def parse(self, file_path: str) -> Union[str, None]:
"""
Parses the file based on its extension and returns the extracted text content.
This is the main asynchronous entry point for parsing.
Args:
file_path (str): The path to the file to be parsed.
Returns:
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
"""
if not file_path or not os.path.exists(file_path):
self.logger.error(f"Invalid file path provided: {file_path}")
return None
file_extension = file_path.split('.')[-1].lower()
parser_method = getattr(self, f'_parse_{file_extension}', None)
if parser_method is None:
self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}")
return None
try:
# Pass file_path to the specific parser methods
return await parser_method(file_path)
except Exception as e:
self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}")
return None
# --- Helper for reading files with encoding detection ---
async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]:
"""
Reads a file with automatic encoding detection, ensuring the synchronous
file read operation runs in a separate thread.
"""
def _read_sync():
with open(file_path, 'rb') as file:
raw_data = file.read()
detected = chardet.detect(raw_data)
encoding = detected['encoding'] or 'utf-8'
if mode == 'r':
return raw_data.decode(encoding, errors='ignore')
return raw_data # For binary mode
return await self._run_sync(_read_sync)
# --- Specific Parser Methods ---
async def _parse_txt(self, file_path: str) -> str:
"""Parses a TXT file and returns its content."""
self.logger.info(f"Parsing TXT file: {file_path}")
return await self._read_file_content(file_path, mode='r')
async def _parse_pdf(self, file_path: str) -> str:
"""Parses a PDF file and returns its text content."""
self.logger.info(f"Parsing PDF file: {file_path}")
def _parse_pdf_sync():
text_content = []
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
text = page.extract_text()
if text:
text_content.append(text)
return '\n'.join(text_content)
return await self._run_sync(_parse_pdf_sync)
async def _parse_docx(self, file_path: str) -> str:
"""Parses a DOCX file and returns its text content."""
self.logger.info(f"Parsing DOCX file: {file_path}")
def _parse_docx_sync():
doc = Document(file_path)
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
return '\n'.join(text_content)
return await self._run_sync(_parse_docx_sync)
async def _parse_doc(self, file_path: str) -> str:
"""Handles .doc files, explicitly stating lack of direct support."""
self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.")
raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.")
async def _parse_xlsx(self, file_path: str) -> str:
"""Parses an XLSX file, returning text from all sheets."""
self.logger.info(f"Parsing XLSX file: {file_path}")
def _parse_xlsx_sync():
excel_file = pd.ExcelFile(file_path)
all_sheet_content = []
for sheet_name in excel_file.sheet_names:
df = pd.read_excel(file_path, sheet_name=sheet_name)
sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n"
all_sheet_content.append(sheet_text)
return '\n'.join(all_sheet_content)
return await self._run_sync(_parse_xlsx_sync)
async def _parse_csv(self, file_path: str) -> str:
"""Parses a CSV file and returns its content as a string."""
self.logger.info(f"Parsing CSV file: {file_path}")
def _parse_csv_sync():
# Detect the encoding up front, then let pandas read the file with it.
with open(file_path, 'rb') as raw_file:
    detected = chardet.detect(raw_file.read())
encoding = detected['encoding'] or 'utf-8'
df = pd.read_csv(file_path, encoding=encoding)
return df.to_string(index=False)
return await self._run_sync(_parse_csv_sync)
async def _parse_markdown(self, file_path: str) -> str:
"""Parses a Markdown file, converting it to structured plain text."""
self.logger.info(f"Parsing Markdown file: {file_path}")
md_content = await self._read_file_content(file_path, mode='r')
def _parse_markdown_sync():
html_content = markdown.markdown(
md_content,
extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
)
soup = BeautifulSoup(html_content, 'html.parser')
text_parts = []
for element in soup.children:
if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
level = int(element.name[1])
text_parts.append('#' * level + ' ' + element.get_text().strip())
elif element.name == 'p':
text = element.get_text().strip()
if text:
text_parts.append(text)
elif element.name in ['ul', 'ol']:
for li in element.find_all('li'):
text_parts.append(f"* {li.get_text().strip()}")
elif element.name == 'pre':
code_block = element.get_text().strip()
if code_block:
text_parts.append(f"```\n{code_block}\n```")
elif element.name == 'table':
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
if table_str:
text_parts.append(table_str)
elif element.name:
text = element.get_text(separator=' ', strip=True)
if text:
text_parts.append(text)
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
return cleaned_text.strip()
return await self._run_sync(_parse_markdown_sync)
async def _parse_html(self, file_path: str) -> str:
"""Parses an HTML file, extracting structured plain text."""
self.logger.info(f"Parsing HTML file: {file_path}")
html_content = await self._read_file_content(file_path, mode='r')
def _parse_html_sync():
soup = BeautifulSoup(html_content, 'html.parser')
for script_or_style in soup(["script", "style"]):
script_or_style.decompose()
text_parts = []
for element in soup.body.children if soup.body else soup.children:
if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
level = int(element.name[1])
text_parts.append('#' * level + ' ' + element.get_text().strip())
elif element.name == 'p':
text = element.get_text().strip()
if text:
text_parts.append(text)
elif element.name in ['ul', 'ol']:
for li in element.find_all('li'):
text = li.get_text().strip()
if text:
text_parts.append(f"* {text}")
elif element.name == 'table':
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
if table_str:
text_parts.append(table_str)
elif element.name:
text = element.get_text(separator=' ', strip=True)
if text:
text_parts.append(text)
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
return cleaned_text.strip()
return await self._run_sync(_parse_html_sync)
async def _parse_epub(self, file_path: str) -> str:
"""Parses an EPUB file, extracting metadata and content."""
self.logger.info(f"Parsing EPUB file: {file_path}")
def _parse_epub_sync():
book = epub.read_epub(file_path)
text_content = []
title_meta = book.get_metadata('DC', 'title')
if title_meta:
text_content.append(f"Title: {title_meta[0][0]}")
creator_meta = book.get_metadata('DC', 'creator')
if creator_meta:
text_content.append(f"Author: {creator_meta[0][0]}")
date_meta = book.get_metadata('DC', 'date')
if date_meta:
text_content.append(f"Publish Date: {date_meta[0][0]}")
toc = book.get_toc()
if toc:
text_content.append("\n--- Table of Contents ---")
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
text_content.append("--- End of Table of Contents ---\n")
for item in book.get_items():
if item.get_type() == ebooklib.ITEM_DOCUMENT:
html_content = item.get_content().decode('utf-8', errors='ignore')
soup = BeautifulSoup(html_content, 'html.parser')
for junk in soup(["script", "style", "nav", "header", "footer"]):
junk.decompose()
text = soup.get_text(separator='\n', strip=True)
text = re.sub(r'\n\s*\n', '\n\n', text)
if text:
text_content.append(text)
return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip()
return await self._run_sync(_parse_epub_sync)
def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
"""Recursively adds TOC items to text_content (synchronous helper)."""
indent = ' ' * level
for item in toc_list:
if isinstance(item, tuple):
chapter, subchapters = item
text_content.append(f"{indent}- {chapter.title}")
self._add_toc_items_sync(subchapters, text_content, level + 1)
else:
text_content.append(f"{indent}- {item.title}")
def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
"""Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
headers = [th.get_text().strip() for th in table_element.find_all('th')]
rows = []
for tr in table_element.find_all('tr'):
cells = [td.get_text().strip() for td in tr.find_all('td')]
if cells:
rows.append(cells)
if not headers and not rows:
return ""
table_lines = []
if headers:
table_lines.append(' | '.join(headers))
table_lines.append(' | '.join(['---'] * len(headers)))
for row_cells in rows:
padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
table_lines.append(' | '.join(padded_cells))
return '\n'.join(table_lines)
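
End to end, the parser is driven through its async `parse` entry point; the path below is a placeholder.

```python
# Usage sketch; the file path is a placeholder.
import asyncio

async def main():
    parser = FileParser()
    text = await parser.parse("./docs/handbook.pdf")   # extension picks _parse_pdf
    if text is None:
        print("unsupported format or parse failure")
    else:
        print(f"extracted {len(text)} characters")

asyncio.run(main())
```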

View File

@@ -0,0 +1,106 @@
# services/retriever.py
import asyncio
import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from services.base_service import BaseService
from services.database import Chunk, SessionLocal
from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
from services.chroma_manager import ChromaIndexManager
logger = logging.getLogger(__name__)
class Retriever(BaseService):
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.model_type = model_type
self.model_name_key = model_name_key
self.chroma_manager = chroma_manager
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
def _load_embedding_model(self) -> BaseEmbeddingModel:
self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...")
try:
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
return model
except Exception as e:
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
raise
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
if not self.embedding_model:
raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.")
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
query_embedding: List[float] = await self.embedding_model.embed_query(query)
query_embedding_np = np.array([query_embedding], dtype=np.float32)
chroma_results = await self._run_sync(
self.chroma_manager.search_sync,
query_embedding_np, k
)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get("ids", [[]])[0]
distances = chroma_results.get("distances", [[]])[0]
chroma_metadatas = chroma_results.get("metadatas", [[]])[0]
chroma_documents = chroma_results.get("documents", [[]])[0]
if not matched_chroma_ids:
self.logger.info("No relevant chunks found in Chroma.")
return []
db_chunk_ids = []
for metadata in chroma_metadatas:
if "chunk_id" in metadata:
db_chunk_ids.append(metadata["chunk_id"])
else:
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
if not db_chunk_ids:
self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.")
return []
self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...")
chunks_from_db = await self._run_sync(
lambda cids: self._db_get_chunks_sync(SessionLocal(), cids),  # a fresh session is created here and closed inside _db_get_chunks_sync
db_chunk_ids
)
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
results_list: List[Dict[str, Any]] = []
for i, chroma_id in enumerate(matched_chroma_ids):
try:
# Ensure original_chunk_id is int for DB lookup
original_chunk_id = int(chroma_id.split('_')[-1])
except (ValueError, IndexError):
self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.")
continue
chunk_text_from_chroma = chroma_documents[i]
distance = float(distances[i])
file_id_from_chroma = chroma_metadatas[i].get("file_id")
chunk_from_db = chunk_map.get(original_chunk_id)
results_list.append({
"chunk_id": original_chunk_id,
"text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
"distance": distance,
"file_id": file_id_from_chroma
})
self.logger.info(f"Retrieved {len(results_list)} chunks.")
return results_list
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).")
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
session.close()
return chunks
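
Wired up as in `RAG_Manager.create_model`, the retriever returns a list of plain dicts; a minimal sketch assuming the `bge-m3` configuration and an existing collection:

```python
# Minimal retrieval sketch, assuming the "bge-m3" configuration and a populated collection.
import asyncio

async def main():
    chroma = ChromaIndexManager(collection_name="rag_collection_bge_m3")
    retriever = Retriever("third_party_api", "bge-m3", chroma)
    for hit in await retriever.retrieve("password reset", k=3):
        # each hit: {"chunk_id": int, "text": str, "distance": float, "file_id": int}
        print(round(hit["distance"], 3), hit["text"][:60])

asyncio.run(main())
```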

View File

View File

@@ -49,6 +49,17 @@ dependencies = [
"ruff>=0.11.9",
"pre-commit>=4.2.0",
"uv>=0.7.11",
"PyPDF2>=3.0.1",
"python-docx>=1.1.0",
"pandas>=2.2.2",
"chardet>=5.2.0",
"markdown>=3.6",
"beautifulsoup4>=4.12.3",
"ebooklib>=0.18",
"html2text>=2024.2.26",
"langchain>=0.2.0",
"chromadb>=0.4.24",
"sentence-transformers>=2.6.1",
]
keywords = [
"bot",