fix: bugs

This commit is contained in:
Junyan Qin
2025-07-11 16:38:08 +08:00
parent 367d04d0f0
commit 9ba1ad5bd3
3 changed files with 15 additions and 58 deletions

View File

@@ -6,11 +6,7 @@ import asyncio
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
from pkg.rag.knowledge.services.embedding_models import EmbeddingModelFactory
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from pkg.core import app
@@ -20,8 +16,6 @@ class RAGManager:
def __init__(self, ap: app.Application, logger: logging.Logger = None):
self.ap = ap
self.logger = logger or logging.getLogger(__name__)
self.embedding_model_type = None
self.embedding_model_name = None
self.chroma_manager = None
self.parser = FileParser()
self.chunker = Chunker()
@@ -32,50 +26,13 @@ class RAGManager:
"""Initializes the RAG system by creating database tables."""
await asyncio.to_thread(create_db_and_tables)
async def create_specific_model(self, embedding_model_type: str, embedding_model_name: str):
"""
Creates and configures the specific embedding model and ChromaDB manager.
This must be called before performing embedding or retrieval operations.
"""
self.embedding_model_type = embedding_model_type
self.embedding_model_name = embedding_model_name
try:
model = EmbeddingModelFactory.create_model(
model_type=self.embedding_model_type, model_name_key=self.embedding_model_name
)
self.logger.info(
f"Configured embedding model '{self.embedding_model_name}' has dimension: {model.embedding_dimension}"
)
except Exception as e:
self.logger.critical(
f"Failed to get dimension for configured embedding model '{self.embedding_model_name}': {e}"
)
raise RuntimeError('Failed to initialize RAG_Manager due to embedding model issues.')
self.chroma_manager = ChromaIndexManager(
collection_name=f'rag_collection_{self.embedding_model_name.replace("-", "_")}'
)
self.embedder = Embedder(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager,
)
self.retriever = Retriever(
model_type=self.embedding_model_type,
model_name_key=self.embedding_model_name,
chroma_manager=self.chroma_manager,
)
async def create_knowledge_base(self, kb_name: str, kb_description: str, embedding_model: str = '', top_k: int = 5):
"""
Creates a new knowledge base if it doesn't already exist.
"""
try:
if not self.embedding_model_type or not kb_name:
raise ValueError(
'Embedding model type and knowledge base name must be set before creating a knowledge base.'
)
if not kb_name:
raise ValueError('Knowledge base name must be set while creating.')
def _create_kb_sync():
session = SessionLocal()