diff --git a/pkg/api/http/controller/groups/knowledge_base.py b/pkg/api/http/controller/groups/knowledge_base.py
new file mode 100644
index 00000000..c819397a
--- /dev/null
+++ b/pkg/api/http/controller/groups/knowledge_base.py
@@ -0,0 +1,83 @@
+from __future__ import annotations
+
+import quart
+
+from .. import group
+
+
+@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
+class KnowledgeBaseRouterGroup(group.RouterGroup):
+
+    # helper for building a success response
+    def success(self, code=0, data=None, msg: str = 'ok') -> quart.Response:
+        return quart.jsonify({
+            "code": code,
+            "data": data or {},
+            "msg": msg
+        })
+
+    async def initialize(self) -> None:
+        rag = self.ap.knowledge_base_service
+
+        @self.route('', methods=['POST', 'GET'])
+        async def _() -> str:
+            if quart.request.method == 'GET':
+                knowledge_bases = await rag.get_all_knowledge_bases()
+                bases_list = [
+                    {
+                        "uuid": kb.id,
+                        "name": kb.name,
+                        "description": kb.description,
+                    } for kb in knowledge_bases
+                ]
+                return self.success(code=0,
+                                    data={'bases': bases_list},
+                                    msg='ok')
+
+            json_data = await quart.request.json
+            await rag.create_knowledge_base(
+                json_data.get('name'),
+                json_data.get('description')
+            )
+            return self.success()
+
+        @self.route('/<knowledge_base_uuid>', methods=['GET'])
+        async def _(knowledge_base_uuid: str) -> str:
+            if quart.request.method == 'GET':
+                knowledge_base = await rag.get_knowledge_base_by_id(knowledge_base_uuid)
+
+                if knowledge_base is None:
+                    return self.http_status(404, -1, 'knowledge base not found')
+
+                return self.success(
+                    code=0,
+                    data={
+                        "name": knowledge_base.name,
+                        "description": knowledge_base.description,
+                        "uuid": knowledge_base.id
+                    },
+                    msg='ok'
+                )
+
+        @self.route('/<knowledge_base_uuid>/files', methods=['GET'])
+        async def _(knowledge_base_uuid: str) -> str:
+            if quart.request.method == 'GET':
+                files = await rag.get_files_by_knowledge_base(knowledge_base_uuid)
+                return self.success(code=0, data=[{
+                    "id": file.id,
+                    "file_name": file.file_name,
+                    "status": file.status
+                } for file in files], msg='ok')
+
+        # delete specific file in knowledge base
+        @self.route('/<knowledge_base_uuid>/files/<file_id>', methods=['DELETE'])
+        async def _(knowledge_base_uuid: str, file_id: str) -> str:
+            await rag.delete_data_by_file_id(file_id)
+            return self.success(code=0, msg='ok')
+
+        # delete specific kb
+        @self.route('/<knowledge_base_uuid>', methods=['DELETE'])
+        async def _(knowledge_base_uuid: str) -> str:
+            await rag.delete_kb_by_id(knowledge_base_uuid)
+            return self.success(code=0, msg='ok')
diff --git a/pkg/core/app.py b/pkg/core/app.py
index 318cddcb..d8824466 100644
--- a/pkg/core/app.py
+++ b/pkg/core/app.py
@@ -27,6 +27,10 @@ from ..storage import mgr as storagemgr
 from ..utils import logcache
 from . import taskmgr
 from . import entities as core_entities
+from ..rag.knowledge.RAG_Manager import RAG_Manager
+
+
+
 
 
 class Application:
@@ -99,6 +103,8 @@ class Application:
 
     storage_mgr: storagemgr.StorageMgr = None
 
+    knowledge_base_service: RAG_Manager = None
+
     # ========= HTTP Services =========
 
     user_service: user_service.UserService = None
@@ -111,6 +117,7 @@ class Application:
 
     bot_service: bot_service.BotService = None
 
+
     def __init__(self):
         pass
 
diff --git a/pkg/rag/knowledge/RAG_Manager.py b/pkg/rag/knowledge/RAG_Manager.py
new file mode 100644
index 00000000..e172c132
--- /dev/null
+++ b/pkg/rag/knowledge/RAG_Manager.py
@@ -0,0 +1,283 @@
+# RAG_Manager class (main class, adjust imports as needed)
+import logging
+import os
+import asyncio
+from .services.parser import FileParser
+from .services.chunker import Chunker
+from .services.embedder import Embedder
+from .services.retriever import Retriever
+from .services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk  # Ensure Chunk is imported if you need to manipulate it directly
+from .services.embedding_models import EmbeddingModelFactory
+from .services.chroma_manager import ChromaIndexManager
+from ...core import app
+
+class RAG_Manager:
+    def __init__(self, logger: logging.Logger = None):
+        self.logger = logger or logging.getLogger(__name__)
+        self.embedding_model_type = None
+        self.embedding_model_name = None
+        self.chroma_manager = None
+        self.parser = None
+        self.chunker = None
+        self.embedder = None
+        self.retriever = None
+        self.ap = app.Application
+
+    async def initialize_system(self):
+        await asyncio.to_thread(create_db_and_tables)
+
+    async def create_model(self, embedding_model_type: str,
+                           embedding_model_name: str):
+        self.embedding_model_type = embedding_model_type
+        self.embedding_model_name = embedding_model_name
+
+        try:
+            model = EmbeddingModelFactory.create_model(
+                model_type=self.embedding_model_type,
+                model_name_key=self.embedding_model_name
+            )
+            self.logger.info(f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}")
+        except Exception as e:
+            self.logger.critical(f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}")
+            raise RuntimeError("Failed to initialize RAG_Manager due to embedding model issues.")
+
+        self.chroma_manager = ChromaIndexManager(collection_name=f"rag_collection_{self.embedding_model_name.replace('-', '_')}")
+
+        self.parser = FileParser()
+        self.chunker = Chunker()
+        # Pass chroma_manager to Embedder and Retriever
+        self.embedder = Embedder(
+            model_type=self.embedding_model_type,
+            model_name_key=self.embedding_model_name,
+            chroma_manager=self.chroma_manager  # Inject dependency
+        )
+        self.retriever = Retriever(
+            model_type=self.embedding_model_type,
+            model_name_key=self.embedding_model_name,
+            chroma_manager=self.chroma_manager  # Inject dependency
+        )
+
+    async def create_knowledge_base(self, kb_name: str, kb_description: str):
+        """
+        Creates a new knowledge base with the given name and description.
+        If a knowledge base with the same name already exists, it returns that one.
+        """
+        try:
+            def _get_kb_sync(name):
+                session = SessionLocal()
+                try:
+                    return session.query(KnowledgeBase).filter_by(name=name).first()
+                finally:
+                    session.close()
+
+            kb = await asyncio.to_thread(_get_kb_sync, kb_name)
+
+            if not kb:
+                def _add_kb_sync():
+                    session = SessionLocal()
+                    try:
+                        new_kb = KnowledgeBase(name=kb_name, description=kb_description)
+                        session.add(new_kb)
+                        session.commit()
+                        session.refresh(new_kb)
+                        return new_kb
+                    finally:
+                        session.close()
+                kb = await asyncio.to_thread(_add_kb_sync)
+            return kb
+        except Exception as e:
+            self.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
+            raise
+
+    async def get_all_knowledge_bases(self):
+        """
+        Retrieves all knowledge bases from the database.
+        """
+        try:
+            def _get_all_kbs_sync():
+                session = SessionLocal()
+                try:
+                    return session.query(KnowledgeBase).all()
+                finally:
+                    session.close()
+
+            kbs = await asyncio.to_thread(_get_all_kbs_sync)
+            return kbs
+        except Exception as e:
+            self.logger.error(f"Error retrieving knowledge bases: {str(e)}", exc_info=True)
+            return []
+
+    async def get_knowledge_base_by_id(self, kb_id: int):
+        """
+        Retrieves a knowledge base by its ID.
+        """
+        try:
+            def _get_kb_sync(kb_id):
+                session = SessionLocal()
+                try:
+                    return session.query(KnowledgeBase).filter_by(id=kb_id).first()
+                finally:
+                    session.close()
+
+            kb = await asyncio.to_thread(_get_kb_sync, kb_id)
+            return kb
+        except Exception as e:
+            self.logger.error(f"Error retrieving knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
+            return None
+
+    async def get_files_by_knowledge_base(self, kb_id: int):
+        try:
+            def _get_files_sync(kb_id):
+                session = SessionLocal()
+                try:
+                    return session.query(File).filter_by(kb_id=kb_id).all()
+                finally:
+                    session.close()
+
+            files = await asyncio.to_thread(_get_files_sync, kb_id)
+            return files
+        except Exception as e:
+            self.logger.error(f"Error retrieving files for knowledge base ID {kb_id}: {str(e)}", exc_info=True)
+            return []
+
+    async def store_data(self, file_path: str, kb_name: str, file_type: str, kb_description: str = "Default knowledge base"):
+        self.logger.info(f"Starting data storage process for file: {file_path}")
+        try:
+            def _get_kb_sync(name):
+                session = SessionLocal()
+                try:
+                    return session.query(KnowledgeBase).filter_by(name=name).first()
+                finally:
+                    session.close()
+
+            kb = await asyncio.to_thread(_get_kb_sync, kb_name)
+
+            if not kb:
+                self.logger.info(f"Knowledge Base '{kb_name}' not found. 
Creating a new one.") + def _add_kb_sync(): + session = SessionLocal() + try: + new_kb = KnowledgeBase(name=kb_name, description=kb_description) + session.add(new_kb) + session.commit() + session.refresh(new_kb) + return new_kb + finally: + session.close() + kb = await asyncio.to_thread(_add_kb_sync) + self.logger.info(f"Created Knowledge Base: {kb.name} (ID: {kb.id})") + + def _add_file_sync(kb_id, file_name, path, file_type): + session = SessionLocal() + try: + file = File(kb_id=kb_id, file_name=file_name, path=path, file_type=file_type) + session.add(file) + session.commit() + session.refresh(file) + return file + finally: + session.close() + + file_obj = await asyncio.to_thread(_add_file_sync, kb.id, os.path.basename(file_path), file_path, file_type) + self.logger.info(f"Added file entry: {file_obj.file_name} (ID: {file_obj.id})") + + text = await self.parser.parse(file_path) + if not text: + self.logger.warning(f"File {file_path} parsed to empty content. Skipping chunking and embedding.") + # You might want to delete the file_obj from the DB here if it's empty. + session = SessionLocal() + try: + session.delete(file_obj) + session.commit() + except Exception as del_e: + self.logger.error(f"Failed to delete empty file_obj {file_obj.id}: {del_e}") + finally: + session.close() + return + + chunks_texts = await self.chunker.chunk(text) + self.logger.info(f"Chunked into {len(chunks_texts)} pieces.") + + # embed_and_store now handles both DB chunk saving and Chroma embedding + await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts) + + self.logger.info(f"Data storage process completed for file: {file_path}") + + except Exception as e: + self.logger.error(f"Error in store_data for file {file_path}: {str(e)}", exc_info=True) + # Consider cleaning up partially stored data if an error occurs. + return + + async def retrieve_data(self, query: str): + self.logger.info(f"Starting data retrieval process for query: '{query}'") + try: + retrieved_chunks = await self.retriever.retrieve(query) + self.logger.info(f"Successfully retrieved {len(retrieved_chunks)} chunks for query.") + return retrieved_chunks + except Exception as e: + self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True) + return [] + + async def delete_data_by_file_id(self, file_id: int): + """ + Deletes data associated with a specific file_id from both the relational DB and Chroma. + """ + self.logger.info(f"Starting data deletion process for file_id: {file_id}") + session = SessionLocal() + try: + # 1. Delete from Chroma + await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id) + + # 2. Delete chunks from relational DB + chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all() + for chunk in chunks_to_delete: + session.delete(chunk) + self.logger.info(f"Deleted {len(chunks_to_delete)} chunks from relational DB for file_id: {file_id}.") + + # 3. 
Delete file entry from relational DB
+            file_to_delete = session.query(File).filter_by(id=file_id).first()
+            if file_to_delete:
+                session.delete(file_to_delete)
+                self.logger.info(f"Deleted file entry {file_id} from relational DB.")
+            else:
+                self.logger.warning(f"File entry {file_id} not found in relational DB.")
+
+            session.commit()
+            self.logger.info(f"Data deletion completed for file_id: {file_id}.")
+        except Exception as e:
+            session.rollback()
+            self.logger.error(f"Error deleting data for file_id {file_id}: {str(e)}", exc_info=True)
+        finally:
+            session.close()
+
+    async def delete_kb_by_id(self, kb_id: int):
+        """
+        Deletes a knowledge base and all associated files and chunks.
+        """
+        self.logger.info(f"Starting deletion of knowledge base with ID: {kb_id}")
+        session = SessionLocal()
+        try:
+            # 1. Get the knowledge base
+            kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
+            if not kb:
+                self.logger.warning(f"Knowledge Base with ID {kb_id} not found.")
+                return
+
+            # 2. Delete all files associated with this knowledge base
+            files_to_delete = session.query(File).filter_by(kb_id=kb.id).all()
+            for file in files_to_delete:
+                await self.delete_data_by_file_id(file.id)
+
+            # 3. Delete the knowledge base itself
+            session.delete(kb)
+            session.commit()
+            self.logger.info(f"Successfully deleted knowledge base with ID: {kb_id}")
+        except Exception as e:
+            session.rollback()
+            self.logger.error(f"Error deleting knowledge base with ID {kb_id}: {str(e)}", exc_info=True)
+        finally:
+            session.close()
diff --git a/pkg/rag/knowledge/services/__init__.py b/pkg/rag/knowledge/services/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/pkg/rag/knowledge/services/base_service.py b/pkg/rag/knowledge/services/base_service.py
new file mode 100644
index 00000000..0298226a
--- /dev/null
+++ b/pkg/rag/knowledge/services/base_service.py
@@ -0,0 +1,26 @@
+# Helpers for running synchronous (blocking) operations from async code
+import asyncio
+import logging
+from .database import SessionLocal  # import the SessionLocal session factory
+
+class BaseService:
+    def __init__(self):
+        self.logger = logging.getLogger(self.__class__.__name__)
+        self.db_session_factory = SessionLocal  # use the SessionLocal factory
+
+    async def _run_sync(self, func, *args, **kwargs):
+        """
+        Runs a synchronous function in a separate thread.
+        If the function expects a database session as its first argument, a new session
+        is acquired inside to_thread and closed afterwards.
+        """
+        # If the function needs a database session as its first argument, acquire it here.
+        if getattr(func, '__name__', '').startswith('_db_'):  # convention: sync DB methods are prefixed with _db_
+            session = await asyncio.to_thread(self.db_session_factory)
+            try:
+                result = await asyncio.to_thread(func, session, *args, **kwargs)
+                return result
+            finally:
+                session.close()
+        else:
+            # Otherwise, run the synchronous function directly in a thread.
+            return await asyncio.to_thread(func, *args, **kwargs)
\ No newline at end of file
diff --git a/pkg/rag/knowledge/services/chroma_manager.py b/pkg/rag/knowledge/services/chroma_manager.py
new file mode 100644
index 00000000..6a469168
--- /dev/null
+++ b/pkg/rag/knowledge/services/chroma_manager.py
@@ -0,0 +1,65 @@
+# services/chroma_manager.py
+import numpy as np
+import logging
+from chromadb import PersistentClient
+import os
+
+logger = logging.getLogger(__name__)
+
+class ChromaIndexManager:
+    def __init__(self, collection_name: str = "default_collection"):
+        self.logger = logging.getLogger(self.__class__.__name__)
+        chroma_data_path = "./chroma_data"
+        os.makedirs(chroma_data_path, exist_ok=True)
+        self.client = PersistentClient(path=chroma_data_path)
+        self._collection_name = collection_name
+        self._collection = None
+
+        self.logger.info(f"ChromaIndexManager initialized. 
Collection name: {self._collection_name}") + + @property + def collection(self): + if self._collection is None: + self._collection = self.client.get_or_create_collection(name=self._collection_name) + self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.") + return self._collection + + def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]): + if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents): + raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.") + + chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)] + metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)] + + self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") + self.collection.add( + embeddings=embeddings.tolist(), + ids=chroma_ids, + metadatas=metadatas, + documents=documents + ) + self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.") + + def search_sync(self, query_embedding: np.ndarray, k: int = 5): + """ + Searches the Chroma collection for the top-k nearest neighbors. + Args: + query_embedding: A numpy array of the query embedding. + k: The number of results to return. + Returns: + A dictionary containing query results from Chroma. + """ + self.logger.debug(f"Searching Chroma collection '{self._collection_name}' with k={k}.") + results = self.collection.query( + query_embeddings=query_embedding.tolist(), + n_results=k, + # REMOVE 'ids' from the include list. It's returned by default. + include=["metadatas", "distances", "documents"] + ) + self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.") + return results + + def delete_by_file_id_sync(self, file_id: int): + self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.") + self.collection.delete(where={"file_id": file_id}) + self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.") \ No newline at end of file diff --git a/pkg/rag/knowledge/services/chunker.py b/pkg/rag/knowledge/services/chunker.py new file mode 100644 index 00000000..f115dac4 --- /dev/null +++ b/pkg/rag/knowledge/services/chunker.py @@ -0,0 +1,63 @@ +# services/chunker.py +import logging +from typing import List +from services.base_service import BaseService # Assuming BaseService provides _run_sync + +logger = logging.getLogger(__name__) + +class Chunker(BaseService): + """ + A class for splitting long texts into smaller, overlapping chunks. + """ + def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50): + super().__init__() # Initialize BaseService + self.logger = logging.getLogger(self.__class__.__name__) + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + if self.chunk_overlap >= self.chunk_size: + self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.") + + def _split_text_sync(self, text: str) -> List[str]: + """ + Synchronously splits a long text into chunks with specified overlap. + This is a CPU-bound operation, intended to be run in a separate thread. 
+ """ + if not text: + return [] + + # Simple whitespace-based splitting for demonstration + # For more advanced chunking, consider libraries like LangChain's text splitters + words = text.split() + chunks = [] + current_chunk = [] + + for word in words: + current_chunk.append(word) + if len(current_chunk) > self.chunk_size: + chunks.append(" ".join(current_chunk[:self.chunk_size])) + current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:] + + if current_chunk: + chunks.append(" ".join(current_chunk)) + + # A more robust chunking strategy (e.g., using recursive character text splitter) + # from langchain.text_splitter import RecursiveCharacterTextSplitter + # text_splitter = RecursiveCharacterTextSplitter( + # chunk_size=self.chunk_size, + # chunk_overlap=self.chunk_overlap, + # length_function=len, + # is_separator_regex=False, + # ) + # return text_splitter.split_text(text) + + return [chunk for chunk in chunks if chunk.strip()] # Filter out empty chunks + + async def chunk(self, text: str) -> List[str]: + """ + Asynchronously chunks a given text into smaller pieces. + """ + self.logger.info(f"Chunking text (length: {len(text)})...") + # Run the synchronous splitting logic in a separate thread + chunks = await self._run_sync(self._split_text_sync, text) + self.logger.info(f"Text chunked into {len(chunks)} pieces.") + return chunks \ No newline at end of file diff --git a/pkg/rag/knowledge/services/database.py b/pkg/rag/knowledge/services/database.py new file mode 100644 index 00000000..4ec21af3 --- /dev/null +++ b/pkg/rag/knowledge/services/database.py @@ -0,0 +1,57 @@ +from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary +from sqlalchemy.orm import declarative_base, sessionmaker, relationship +from datetime import datetime +import numpy as np # 用于处理从LargeBinary转换回来的embedding + +Base = declarative_base() + +class KnowledgeBase(Base): + __tablename__ = 'kb' + id = Column(Integer, primary_key=True, index=True) + name = Column(String, index=True) + description = Column(Text) + created_at = Column(DateTime, default=datetime.utcnow) + + files = relationship("File", back_populates="knowledge_base") + +class File(Base): + __tablename__ = 'file' + id = Column(Integer, primary_key=True, index=True) + kb_id = Column(Integer, ForeignKey('kb.id')) + file_name = Column(String) + path = Column(String) + created_at = Column(DateTime, default=datetime.utcnow) + file_type = Column(String) + status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误 + knowledge_base = relationship("KnowledgeBase", back_populates="files") + chunks = relationship("Chunk", back_populates="file") + +class Chunk(Base): + __tablename__ = 'chunks' + id = Column(Integer, primary_key=True, index=True) + file_id = Column(Integer, ForeignKey('file.id')) + text = Column(Text) + + file = relationship("File", back_populates="chunks") + vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one + +class Vector(Base): + __tablename__ = 'vectors' + id = Column(Integer, primary_key=True, index=True) + chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True) + embedding = Column(LargeBinary) # Store embeddings as binary + + chunk = relationship("Chunk", back_populates="vector") + +# 数据库连接 +DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL +engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {}) +SessionLocal = sessionmaker(autocommit=False, 
autoflush=False, bind=engine) + +# 创建所有表 (可以在应用启动时执行一次) +def create_db_and_tables(): + Base.metadata.create_all(bind=engine) + print("Database tables created/checked.") + +# 定义嵌入维度(请根据你实际使用的模型调整) +EMBEDDING_DIM = 1024 \ No newline at end of file diff --git a/pkg/rag/knowledge/services/embedder.py b/pkg/rag/knowledge/services/embedder.py new file mode 100644 index 00000000..2b581e96 --- /dev/null +++ b/pkg/rag/knowledge/services/embedder.py @@ -0,0 +1,93 @@ +# services/embedder.py +import asyncio +import logging +import numpy as np +from typing import List +from sqlalchemy.orm import Session +from services.base_service import BaseService +from services.database import Chunk, SessionLocal +from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager # Import the manager + +logger = logging.getLogger(__name__) + +class Embedder(BaseService): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.model_type = model_type + self.model_name_key = model_name_key + self.chroma_manager = chroma_manager # Dependency Injection + + self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() + + def _load_embedding_model(self) -> BaseEmbeddingModel: + self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...") + try: + model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) + self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + return model + except Exception as e: + self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}") + raise + + def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]): + """ + Saves chunks to the relational database and returns the created Chunk objects. + This function assumes it's called within a context where the session + will be committed/rolled back and closed by the caller. + """ + self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).") + chunk_objects = [] + for text in chunks_texts: + chunk = Chunk(file_id=file_id, text=text) + session.add(chunk) + chunk_objects.append(chunk) + session.flush() # This populates the .id attribute for each new chunk object + self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.") + return chunk_objects + + async def embed_and_store(self, file_id: int, chunks: List[str]): + if not self.embedding_model: + raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.") + + self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...") + + session = SessionLocal() # Start a session that will live for the whole operation + chunk_objects = [] + try: + # 1. Save chunks to the relational database first to get their IDs + # We call _db_save_chunks_sync directly without _run_sync's session management + # because we manage the session here across multiple async calls. + chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks) + session.commit() # Commit chunks to make their IDs permanent and accessible + + if not chunk_objects: + self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.") + return [] + + # 2. 
Generate embeddings + embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks) + embeddings_np = np.array(embeddings, dtype=np.float32) + + if embeddings_np.shape[1] != self.embedding_model.embedding_dimension: + self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.") + raise ValueError("Embedding dimension mismatch during embedding process.") + + self.logger.info("Saving embeddings to Chroma...") + chunk_ids = [c.id for c in chunk_objects] # Now safe to access .id because session is still open and committed + file_ids_for_chroma = [file_id] * len(chunk_ids) + + await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call + self.chroma_manager.add_embeddings_sync, + file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents + ) + self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.") + return chunk_objects + + except Exception as e: + session.rollback() # Rollback on any error + self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True) + raise # Re-raise the exception to propagate it + finally: + session.close() # Ensure the session is always closed \ No newline at end of file diff --git a/pkg/rag/knowledge/services/embedding_models.py b/pkg/rag/knowledge/services/embedding_models.py new file mode 100644 index 00000000..a6ce73ae --- /dev/null +++ b/pkg/rag/knowledge/services/embedding_models.py @@ -0,0 +1,223 @@ +# services/embedding_models.py + +import os +from typing import Dict, Any, List, Type, Optional +import logging +import aiohttp # Import aiohttp for asynchronous requests +import asyncio +from sentence_transformers import SentenceTransformer + +logger = logging.getLogger(__name__) + +# Base class for all embedding models +class BaseEmbeddingModel: + def __init__(self, model_name: str): + self.model_name = model_name + self._embedding_dimension = None + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embeds a list of texts.""" + raise NotImplementedError + + async def embed_query(self, text: str) -> List[float]: + """Asynchronously embeds a single query text.""" + raise NotImplementedError + + @property + def embedding_dimension(self) -> int: + """Returns the embedding dimension of the model.""" + if self._embedding_dimension is None: + raise NotImplementedError("Embedding dimension not set for this model.") + return self._embedding_dimension + +class EmbeddingModelFactory: + @staticmethod + def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel: + """ + Factory method to create an embedding model instance. + Currently only supports 'third_party_api' types. + """ + if model_name_key not in EMBEDDING_MODEL_CONFIGS: + raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.") + + config = EMBEDDING_MODEL_CONFIGS[model_name_key] + + if config['type'] == "third_party_api": + required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension'] + if not all(key in config for key in required_keys): + raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. 
Required: {required_keys}") + + # Retrieve model_name from config if it differs from model_name_key + # Some APIs expect a specific 'model' value in the payload that might be different from the key + api_model_name = config.get('model_name', model_name_key) + + return ThirdPartyAPIEmbeddingModel( + model_name=api_model_name, # Use the model_name from config or the key + api_endpoint=config['api_endpoint'], + headers=config['headers'], + payload_template=config['payload_template'], + embedding_dimension=config['embedding_dimension'] + ) + +class SentenceTransformerEmbeddingModel(BaseEmbeddingModel): + def __init__(self, model_name: str): + super().__init__(model_name) + try: + # SentenceTransformer is inherently synchronous, but we'll wrap its calls + # in async methods. The actual computation will still block the event loop + # if not run in a separate thread/process, but this keeps the API consistent. + self.model = SentenceTransformer(model_name) + self._embedding_dimension = self.model.get_sentence_embedding_dimension() + logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}") + except Exception as e: + logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}") + raise + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + # For CPU-bound tasks like local model inference, consider running in a thread pool + # to prevent blocking the event loop for long operations. + # For simplicity here, we'll call it directly. + return self.model.encode(texts).tolist() + + async def embed_query(self, text: str) -> List[float]: + return self.model.encode(text).tolist() + + +class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel): + def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int): + super().__init__(model_name) + self.api_endpoint = api_endpoint + self.headers = headers + self.payload_template = payload_template + self._embedding_dimension = embedding_dimension + self.session = None # aiohttp client session will be initialized on first use or in a context manager + logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}") + + async def _get_session(self): + """Lazily create or return the aiohttp client session.""" + if self.session is None or self.session.closed: + self.session = aiohttp.ClientSession() + return self.session + + async def close_session(self): + """Explicitly close the aiohttp client session.""" + if self.session and not self.session.closed: + await self.session.close() + self.session = None + logger.info(f"Closed aiohttp session for model {self.model_name}") + + async def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Asynchronously embeds a list of texts using the third-party API.""" + session = await self._get_session() + embeddings = [] + tasks = [] + for text in texts: + payload = self.payload_template.copy() + if 'input' in payload: + payload['input'] = text + elif 'texts' in payload: + payload['texts'] = [text] + else: + raise ValueError("Payload template does not contain expected text input key.") + + tasks.append(self._make_api_request(session, payload)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + + for i, res in enumerate(results): + if isinstance(res, Exception): + logger.error(f"Error embedding text '{texts[i][:50]}...': {res}") + # Depending on your error handling 
strategy, you might: + # - Append None or an empty list + # - Re-raise the exception to stop processing + # - Log and skip, then continue + embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure + else: + embeddings.append(res) + + return embeddings + + async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]: + """Helper to make an asynchronous API request and extract embedding.""" + try: + async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response: + response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx) + api_response = await response.json() + + # Adjust this based on your API's actual response structure + if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]: + embedding = api_response["data"][0]["embedding"] + if len(embedding) != self.embedding_dimension: + logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + return embedding + elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]: + embedding = api_response["embeddings"][0] + if len(embedding) != self.embedding_dimension: + logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.") + return embedding + else: + raise ValueError(f"Unexpected API response structure: {api_response}") + + except aiohttp.ClientError as e: + raise ConnectionError(f"API request failed: {e}") from e + except ValueError as e: + raise ValueError(f"Error processing API response: {e}") from e + + + async def embed_query(self, text: str) -> List[float]: + """Asynchronously embeds a single query text.""" + results = await self.embed_documents([text]) + if results: + return results[0] + return [] # Or raise an error if embedding a query must always succeed + +# --- Embedding Model Configuration --- +EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = { + "MiniLM": { # Example for a local Sentence Transformer model + "type": "sentence_transformer", + "model_name": "sentence-transformers/all-MiniLM-L6-v2" + }, + "bge-m3": { # Example for a third-party API model + "type": "third_party_api", + "model_name": "bge-m3", + "api_endpoint": "https://api.qhaigc.net/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('rag_api_key')}" + }, + "payload_template": { + "model": "bge-m3", + "input": "" + }, + "embedding_dimension": 1024 + }, + "OpenAI-Ada-002": { + "type": "third_party_api", + "model_name": "text-embedding-ada-002", + "api_endpoint": "https://api.openai.com/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set + }, + "payload_template": { + "model": "text-embedding-ada-002", + "input": "" # Text will be injected here + }, + "embedding_dimension": 1536 + }, + "OpenAI-Embedding-3-Small": { + "type": "third_party_api", + "model_name": "text-embedding-3-small", + "api_endpoint": "https://api.openai.com/v1/embeddings", + "headers": { + "Content-Type": "application/json", + "Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" + }, + "payload_template": { + "model": "text-embedding-3-small", + "input": "", + # "dimensions": 
512 # Optional: uncomment if you want a specific output dimension + }, + "embedding_dimension": 1536 # Default max dimension for text-embedding-3-small + }, +} \ No newline at end of file diff --git a/pkg/rag/knowledge/services/parser.py b/pkg/rag/knowledge/services/parser.py new file mode 100644 index 00000000..5fa7d589 --- /dev/null +++ b/pkg/rag/knowledge/services/parser.py @@ -0,0 +1,288 @@ + +import PyPDF2 +from docx import Document +import pandas as pd +import csv +import chardet +from typing import Union, List, Callable, Any +import logging +import markdown +from bs4 import BeautifulSoup +import ebooklib +from ebooklib import epub +import re +import asyncio # Import asyncio for async operations +import os + +# Configure logging +logger = logging.getLogger(__name__) + +class FileParser: + """ + A robust file parser class to extract text content from various document formats. + It supports TXT, PDF, DOCX, XLSX, CSV, Markdown, HTML, and EPUB files. + All core file reading operations are designed to be run synchronously in a thread pool + to avoid blocking the asyncio event loop. + """ + def __init__(self): + + self.logger = logging.getLogger(self.__class__.__name__) + + async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any: + """ + Runs a synchronous function in a separate thread to prevent blocking the event loop. + This is a general utility method for wrapping blocking I/O operations. + """ + try: + return await asyncio.to_thread(sync_func, *args, **kwargs) + except Exception as e: + self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}") + raise + + async def parse(self, file_path: str) -> Union[str, None]: + """ + Parses the file based on its extension and returns the extracted text content. + This is the main asynchronous entry point for parsing. + + Args: + file_path (str): The path to the file to be parsed. + + Returns: + Union[str, None]: The extracted text content as a single string, or None if parsing fails. + """ + if not file_path or not os.path.exists(file_path): + self.logger.error(f"Invalid file path provided: {file_path}") + return None + + file_extension = file_path.split('.')[-1].lower() + parser_method = getattr(self, f'_parse_{file_extension}', None) + + if parser_method is None: + self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}") + return None + + try: + # Pass file_path to the specific parser methods + return await parser_method(file_path) + except Exception as e: + self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}") + return None + + # --- Helper for reading files with encoding detection --- + async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]: + """ + Reads a file with automatic encoding detection, ensuring the synchronous + file read operation runs in a separate thread. 
+ """ + def _read_sync(): + with open(file_path, 'rb') as file: + raw_data = file.read() + detected = chardet.detect(raw_data) + encoding = detected['encoding'] or 'utf-8' + + if mode == 'r': + return raw_data.decode(encoding, errors='ignore') + return raw_data # For binary mode + + return await self._run_sync(_read_sync) + + # --- Specific Parser Methods --- + + async def _parse_txt(self, file_path: str) -> str: + """Parses a TXT file and returns its content.""" + self.logger.info(f"Parsing TXT file: {file_path}") + return await self._read_file_content(file_path, mode='r') + + async def _parse_pdf(self, file_path: str) -> str: + """Parses a PDF file and returns its text content.""" + self.logger.info(f"Parsing PDF file: {file_path}") + def _parse_pdf_sync(): + text_content = [] + with open(file_path, 'rb') as file: + pdf_reader = PyPDF2.PdfReader(file) + for page in pdf_reader.pages: + text = page.extract_text() + if text: + text_content.append(text) + return '\n'.join(text_content) + return await self._run_sync(_parse_pdf_sync) + + async def _parse_docx(self, file_path: str) -> str: + """Parses a DOCX file and returns its text content.""" + self.logger.info(f"Parsing DOCX file: {file_path}") + def _parse_docx_sync(): + doc = Document(file_path) + text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()] + return '\n'.join(text_content) + return await self._run_sync(_parse_docx_sync) + + async def _parse_doc(self, file_path: str) -> str: + """Handles .doc files, explicitly stating lack of direct support.""" + self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.") + raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.") + + async def _parse_xlsx(self, file_path: str) -> str: + """Parses an XLSX file, returning text from all sheets.""" + self.logger.info(f"Parsing XLSX file: {file_path}") + def _parse_xlsx_sync(): + excel_file = pd.ExcelFile(file_path) + all_sheet_content = [] + for sheet_name in excel_file.sheet_names: + df = pd.read_excel(file_path, sheet_name=sheet_name) + sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n" + all_sheet_content.append(sheet_text) + return '\n'.join(all_sheet_content) + return await self._run_sync(_parse_xlsx_sync) + + async def _parse_csv(self, file_path: str) -> str: + """Parses a CSV file and returns its content as a string.""" + self.logger.info(f"Parsing CSV file: {file_path}") + def _parse_csv_sync(): + # pd.read_csv can often detect encoding, but explicit detection is safer + raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function + # For simplicity, we'll let pandas handle encoding internally after a raw read. + # A more robust solution might pass encoding directly to pd.read_csv after detection. 
+            detected = chardet.detect(open(file_path, 'rb').read())
+            encoding = detected['encoding'] or 'utf-8'
+            df = pd.read_csv(file_path, encoding=encoding)
+            return df.to_string(index=False)
+        return await self._run_sync(_parse_csv_sync)
+
+    async def _parse_markdown(self, file_path: str) -> str:
+        """Parses a Markdown file, converting it to structured plain text."""
+        self.logger.info(f"Parsing Markdown file: {file_path}")
+        # Read the file in the async wrapper; the sync helper must not call the async reader.
+        md_content = await self._read_file_content(file_path, mode='r')
+        def _parse_markdown_sync():
+            html_content = markdown.markdown(
+                md_content,
+                extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
+            )
+            soup = BeautifulSoup(html_content, 'html.parser')
+            text_parts = []
+            for element in soup.children:
+                if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
+                    level = int(element.name[1])
+                    text_parts.append('#' * level + ' ' + element.get_text().strip())
+                elif element.name == 'p':
+                    text = element.get_text().strip()
+                    if text:
+                        text_parts.append(text)
+                elif element.name in ['ul', 'ol']:
+                    for li in element.find_all('li'):
+                        text_parts.append(f"* {li.get_text().strip()}")
+                elif element.name == 'pre':
+                    code_block = element.get_text().strip()
+                    if code_block:
+                        text_parts.append(f"```\n{code_block}\n```")
+                elif element.name == 'table':
+                    table_str = self._extract_table_to_markdown_sync(element)  # Call sync helper
+                    if table_str:
+                        text_parts.append(table_str)
+                elif element.name:
+                    text = element.get_text(separator=' ', strip=True)
+                    if text:
+                        text_parts.append(text)
+            cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
+            return cleaned_text.strip()
+        return await self._run_sync(_parse_markdown_sync)
+
+    async def _parse_html(self, file_path: str) -> str:
+        """Parses an HTML file, extracting structured plain text."""
+        self.logger.info(f"Parsing HTML file: {file_path}")
+        # Read the file in the async wrapper; the sync helper must not call the async reader.
+        html_content = await self._read_file_content(file_path, mode='r')
+        def _parse_html_sync():
+            soup = BeautifulSoup(html_content, 'html.parser')
+            for script_or_style in soup(["script", "style"]):
+                script_or_style.decompose()
+            text_parts = []
+            for element in soup.body.children if soup.body else soup.children:
+                if element.name in ['h1', 'h2', 'h3', 'h4', 'h5', 'h6']:
+                    level = int(element.name[1])
+                    text_parts.append('#' * level + ' ' + element.get_text().strip())
+                elif element.name == 'p':
+                    text = element.get_text().strip()
+                    if text:
+                        text_parts.append(text)
+                elif element.name in ['ul', 'ol']:
+                    for li in element.find_all('li'):
+                        text = li.get_text().strip()
+                        if text:
+                            text_parts.append(f"* {text}")
+                elif element.name == 'table':
+                    table_str = self._extract_table_to_markdown_sync(element)  # Call sync helper
+                    if table_str:
+                        text_parts.append(table_str)
+                elif element.name:
+                    text = element.get_text(separator=' ', strip=True)
+                    if text:
+                        text_parts.append(text)
+            cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
+            return cleaned_text.strip()
+        return await self._run_sync(_parse_html_sync)
+
+    async def _parse_epub(self, file_path: str) -> str:
+        """Parses an EPUB file, extracting metadata and content."""
+        self.logger.info(f"Parsing EPUB file: {file_path}")
+        def _parse_epub_sync():
+            book = epub.read_epub(file_path)
+            text_content = []
+            title_meta = book.get_metadata('DC', 'title')
+            if title_meta:
+                text_content.append(f"Title: {title_meta[0][0]}")
+            creator_meta = book.get_metadata('DC', 'creator')
+            if creator_meta:
+                text_content.append(f"Author: 
{creator_meta[0][0]}") + date_meta = book.get_metadata('DC', 'date') + if date_meta: + text_content.append(f"Publish Date: {date_meta[0][0]}") + toc = book.get_toc() + if toc: + text_content.append("\n--- Table of Contents ---") + self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper + text_content.append("--- End of Table of Contents ---\n") + for item in book.get_items(): + if item.get_type() == ebooklib.ITEM_DOCUMENT: + html_content = item.get_content().decode('utf-8', errors='ignore') + soup = BeautifulSoup(html_content, 'html.parser') + for junk in soup(["script", "style", "nav", "header", "footer"]): + junk.decompose() + text = soup.get_text(separator='\n', strip=True) + text = re.sub(r'\n\s*\n', '\n\n', text) + if text: + text_content.append(text) + return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip() + return await self._run_sync(_parse_epub_sync) + + def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int): + """Recursively adds TOC items to text_content (synchronous helper).""" + indent = ' ' * level + for item in toc_list: + if isinstance(item, tuple): + chapter, subchapters = item + text_content.append(f"{indent}- {chapter.title}") + self._add_toc_items_sync(subchapters, text_content, level + 1) + else: + text_content.append(f"{indent}- {item.title}") + + def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str: + """Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous).""" + headers = [th.get_text().strip() for th in table_element.find_all('th')] + rows = [] + for tr in table_element.find_all('tr'): + cells = [td.get_text().strip() for td in tr.find_all('td')] + if cells: + rows.append(cells) + + if not headers and not rows: + return "" + + table_lines = [] + if headers: + table_lines.append(' | '.join(headers)) + table_lines.append(' | '.join(['---'] * len(headers))) + + for row_cells in rows: + padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells + table_lines.append(' | '.join(padded_cells)) + + return '\n'.join(table_lines) \ No newline at end of file diff --git a/pkg/rag/knowledge/services/retriever.py b/pkg/rag/knowledge/services/retriever.py new file mode 100644 index 00000000..6da1c5d8 --- /dev/null +++ b/pkg/rag/knowledge/services/retriever.py @@ -0,0 +1,106 @@ +# services/retriever.py +import asyncio +import logging +import numpy as np # Make sure numpy is imported +from typing import List, Dict, Any +from sqlalchemy.orm import Session +from services.base_service import BaseService +from services.database import Chunk, SessionLocal +from services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory +from services.chroma_manager import ChromaIndexManager + +logger = logging.getLogger(__name__) + +class Retriever(BaseService): + def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager): + super().__init__() + self.logger = logging.getLogger(self.__class__.__name__) + self.model_type = model_type + self.model_name_key = model_name_key + self.chroma_manager = chroma_manager + + self.embedding_model: BaseEmbeddingModel = self._load_embedding_model() + + def _load_embedding_model(self) -> BaseEmbeddingModel: + self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...") + try: + model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key) + self.logger.info(f"Retriever embedding model 
'{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}") + return model + except Exception as e: + self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}") + raise + + async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]: + if not self.embedding_model: + raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.") + + self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}") + + query_embedding: List[float] = await self.embedding_model.embed_query(query) + query_embedding_np = np.array([query_embedding], dtype=np.float32) + + chroma_results = await self._run_sync( + self.chroma_manager.search_sync, + query_embedding_np, k + ) + + # 'ids' is always returned by ChromaDB, even if not explicitly in 'include' + matched_chroma_ids = chroma_results.get("ids", [[]])[0] + distances = chroma_results.get("distances", [[]])[0] + chroma_metadatas = chroma_results.get("metadatas", [[]])[0] + chroma_documents = chroma_results.get("documents", [[]])[0] + + if not matched_chroma_ids: + self.logger.info("No relevant chunks found in Chroma.") + return [] + + db_chunk_ids = [] + for metadata in chroma_metadatas: + if "chunk_id" in metadata: + db_chunk_ids.append(metadata["chunk_id"]) + else: + self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.") + + if not db_chunk_ids: + self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.") + return [] + + self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...") + chunks_from_db = await self._run_sync( + lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync + db_chunk_ids + ) + + chunk_map = {chunk.id: chunk for chunk in chunks_from_db} + results_list: List[Dict[str, Any]] = [] + + for i, chroma_id in enumerate(matched_chroma_ids): + try: + # Ensure original_chunk_id is int for DB lookup + original_chunk_id = int(chroma_id.split('_')[-1]) + except (ValueError, IndexError): + self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. 
Skipping.") + continue + + chunk_text_from_chroma = chroma_documents[i] + distance = float(distances[i]) + file_id_from_chroma = chroma_metadatas[i].get("file_id") + + chunk_from_db = chunk_map.get(original_chunk_id) + + results_list.append({ + "chunk_id": original_chunk_id, + "text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma, + "distance": distance, + "file_id": file_id_from_chroma + }) + + self.logger.info(f"Retrieved {len(results_list)} chunks.") + return results_list + + def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]: + self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).") + chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all() + session.close() + return chunks \ No newline at end of file diff --git a/pkg/rag/knowledge/utils/crawler.py b/pkg/rag/knowledge/utils/crawler.py new file mode 100644 index 00000000..e69de29b diff --git a/pyproject.toml b/pyproject.toml index 5e85bfb0..27a03a92 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,6 +49,17 @@ dependencies = [ "ruff>=0.11.9", "pre-commit>=4.2.0", "uv>=0.7.11", + "PyPDF2>=3.0.1", + "python-docx>=1.1.0", + "pandas>=2.2.2", + "chardet>=5.2.0", + "markdown>=3.6", + "beautifulsoup4>=4.12.3", + "ebooklib>=0.18", + "html2text>=2024.2.26", + "langchain>=0.2.0", + "chromadb>=0.4.24", + "sentence-transformers>=2.6.1", ] keywords = [ "bot",