Mirror of https://github.com/langbot-app/LangBot.git

feat: add embedder
@@ -27,6 +27,7 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
                 json_data.get('name'),
                 json_data.get('description'),
+                json_data.get('embedding_model_uuid'),
                 json_data.get('top_k', 5),
             )
             return self.success(data={'uuid': knowledge_base_uuid})
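The route now threads an `embedding_model_uuid` through knowledge-base creation, alongside the existing `name`, `description`, and optional `top_k` (default 5). A client-side sketch of the request body; the endpoint path and port are placeholders, since the diff does not show the route prefix:

import requests

# Hypothetical endpoint path and port; the real route prefix is defined by
# KnowledgeBaseRouterGroup and is not visible in this diff.
resp = requests.post(
    'http://localhost:5300/api/v1/knowledge/bases',
    json={
        'name': 'docs-kb',
        'description': 'Product documentation',
        'embedding_model_uuid': '<uuid-of-a-configured-embedding-model>',
        'top_k': 5,  # optional; the handler falls back to 5
    },
)
print(resp.json())  # the handler returns {'uuid': <new knowledge base uuid>} as its data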
@@ -100,7 +100,6 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):

     async def invoke_embedding(
         self,
-        query: core_entities.Query,
         model: RuntimeEmbeddingModel,
         input_text: str,
         extra_args: dict[str, typing.Any] = {},
@@ -144,7 +144,6 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):

     async def invoke_embedding(
         self,
-        query: core_entities.Query,
         model: requester.RuntimeEmbeddingModel,
         input_text: str,
         extra_args: dict[str, typing.Any] = {},
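The hunk shows only the updated signature; the method body is outside this diff. For orientation, a minimal self-contained sketch of what an OpenAI-backed `invoke_embedding` of this shape can look like, using the official `openai` SDK; the free-function form, client wiring, and model name are illustrative assumptions, not LangBot's actual code:

import asyncio
import typing

import openai


async def invoke_embedding(
    client: openai.AsyncOpenAI,
    model_name: str,
    input_text: str,
    extra_args: dict[str, typing.Any] = {},
) -> list[float]:
    # The embeddings endpoint returns one vector per input string.
    resp = await client.embeddings.create(model=model_name, input=input_text, **extra_args)
    return resp.data[0].embedding


async def main() -> None:
    client = openai.AsyncOpenAI()  # reads OPENAI_API_KEY from the environment
    vec = await invoke_embedding(client, 'text-embedding-3-small', 'hello world')
    print(len(vec))


if __name__ == '__main__':
    asyncio.run(main())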
@@ -1,4 +1,4 @@
-# rag_manager.py
+from __future__ import annotations

 import os
 import asyncio
@@ -10,6 +10,8 @@ from pkg.core import app
 from pkg.rag.knowledge.services.embedder import Embedder
 from pkg.rag.knowledge.services.retriever import Retriever
 from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
+from ...entity.persistence import model as persistence_model
+import sqlalchemy


 class RAGManager:
@@ -20,9 +22,8 @@ class RAGManager:
         self.chroma_manager = ChromaIndexManager()
         self.parser = FileParser()
         self.chunker = Chunker()
-        # Initialize Embedder with targeted model type and name
-        self.embedder = Embedder(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
-        self.retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=self.chroma_manager)
+        self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
+        self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)

     async def initialize_rag_system(self):
         """Initializes the RAG system by creating database tables."""
@@ -55,6 +56,7 @@ class RAGManager:
                 session.commit()
                 session.refresh(new_kb)
                 self.ap.logger.info(f"Knowledge Base '{kb_name}' created.")
+                print(embedding_model_uuid)
                 return new_kb.id
             else:
                 self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
@@ -158,10 +160,9 @@ class RAGManager:
             kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
             if not kb:
                 self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist.')
-                self.ap.logger.info(f'Created Knowledge Base with ID: {kb_id}')
-            else:
-                self.ap.logger.info(f"Knowledge Base '{kb_id}' already exists.")
                 return
+            # get embedding model
+            embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid)
             file_name = os.path.basename(file_path)
             text = await self.parser.parse(file_path)
             if not text:
@@ -172,7 +173,7 @@ class RAGManager:

             chunks_texts = await self.chunker.chunk(text)
             self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
-            await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts)
+            await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model)
             self.ap.logger.info(f'Data storage process completed for file: {file_path}')

         except Exception as e:
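Taken together, the `rag_manager.py` hunks move the design from "the embedder owns a hard-coded model" to "resolve the knowledge base's configured model once, then pass it down". A framework-free toy of that dependency-passing pattern (all names here are illustrative, not LangBot's):

import asyncio
from dataclasses import dataclass
from typing import Awaitable, Callable, List


@dataclass
class EmbeddingModel:
    uuid: str
    embed: Callable[[str], Awaitable[List[float]]]


class ModelManager:
    def __init__(self, models: dict[str, EmbeddingModel]) -> None:
        self._models = models

    async def get_embedding_model_by_uuid(self, uuid: str) -> EmbeddingModel:
        return self._models[uuid]


async def store_file(model_mgr: ModelManager, kb_model_uuid: str, chunks: List[str]) -> List[List[float]]:
    # Resolve the KB's configured model once, then reuse it for every chunk.
    model = await model_mgr.get_embedding_model_by_uuid(kb_model_uuid)
    return [await model.embed(chunk) for chunk in chunks]


async def main() -> None:
    async def fake_embed(text: str) -> List[float]:
        return [float(len(text))]  # stand-in for a real embeddings API call

    mgr = ModelManager({'kb-model': EmbeddingModel('kb-model', fake_embed)})
    print(await store_file(mgr, 'kb-model', ['hello', 'world!']))


asyncio.run(main())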
@@ -1,4 +1,4 @@
-# services/embedder.py
+from __future__ import annotations
 import asyncio
 import logging
 import numpy as np
@@ -6,30 +6,23 @@ from typing import List
 from sqlalchemy.orm import Session
 from pkg.rag.knowledge.services.base_service import BaseService
 from pkg.rag.knowledge.services.database import Chunk, SessionLocal
-from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
-from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager  # Import the manager
+from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
+from sqlalchemy.orm import declarative_base, sessionmaker
+from ....core import app
+from ....entity.persistence import model as persistence_model
+import sqlalchemy
+from ....provider.modelmgr.requester import RuntimeEmbeddingModel


+base = declarative_base()
 logger = logging.getLogger(__name__)


 class Embedder(BaseService):
-    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager = None):
+    def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None:
         super().__init__()
         self.logger = logging.getLogger(self.__class__.__name__)
-        self.model_type = model_type
-        self.model_name_key = model_name_key
-        self.chroma_manager = chroma_manager  # Dependency Injection
-
-        self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
-
-    def _load_embedding_model(self) -> BaseEmbeddingModel:
-        self.logger.info(f"Loading embedding model: type={self.model_type}, name_key={self.model_name_key}...")
-        try:
-            model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
-            self.logger.info(f"Embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
-            return model
-        except Exception as e:
-            self.logger.error(f"Failed to load embedding model '{self.model_name_key}': {e}")
-            raise
+        self.chroma_manager = chroma_manager
+        self.ap = ap

     def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
         """
@@ -47,12 +40,10 @@ class Embedder(BaseService):
         self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.")
         return chunk_objects

-    async def embed_and_store(self, file_id: int, chunks: List[str]):
-        if not self.embedding_model:
+    async def embed_and_store(self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel) -> List[Chunk]:
+        if not embedding_model:
             raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.")

-        self.logger.info(f"Embedding {len(chunks)} chunks for file_id: {file_id} using {self.model_name_key}...")
-
         session = SessionLocal()  # Start a session that will live for the whole operation
         chunk_objects = []
         try:
@@ -65,17 +56,23 @@ class Embedder(BaseService):
             if not chunk_objects:
                 self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.")
                 return []

-            # 2. Generate embeddings
-            embeddings: List[List[float]] = await self.embedding_model.embed_documents(chunks)
+            # get the embeddings for the chunks
+            embeddings = []
+            i = 0
+            while i < len(chunks):
+                chunk = chunks[i]
+                result = await embedding_model.requester.invoke_embedding(
+                    model=embedding_model,
+                    input_text=chunk,
+                )
+                embeddings.append(result)
+                i += 1

             embeddings_np = np.array(embeddings, dtype=np.float32)

-            if embeddings_np.shape[1] != self.embedding_model.embedding_dimension:
-                self.logger.error(f"Mismatch in embedding dimension: Model returned {embeddings_np.shape[1]}, expected {self.embedding_model.embedding_dimension}. Aborting storage.")
-                raise ValueError("Embedding dimension mismatch during embedding process.")

             self.logger.info("Saving embeddings to Chroma...")
-            chunk_ids = [c.id for c in chunk_objects]  # Now safe to access .id because session is still open and committed
+            chunk_ids = [c.id for c in chunk_objects]
             file_ids_for_chroma = [file_id] * len(chunk_ids)

             await self._run_sync(  # Use _run_sync for the Chroma operation, as it's a sync call
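The new loop awaits one `invoke_embedding` call per chunk, strictly in sequence, which keeps chunk/embedding ordering trivially correct. If ingestion throughput ever matters, the same per-chunk call could be issued with bounded concurrency; a self-contained sketch of that pattern (not part of this commit; `embed_one` stands in for the requester call):

import asyncio
from typing import Awaitable, Callable, List


async def embed_all(
    chunks: List[str],
    embed_one: Callable[[str], Awaitable[List[float]]],
    max_concurrency: int = 8,
) -> List[List[float]]:
    # Bound the number of in-flight requests so a large file does not
    # flood the embeddings API.
    sem = asyncio.Semaphore(max_concurrency)

    async def guarded(chunk: str) -> List[float]:
        async with sem:
            return await embed_one(chunk)

    # gather() preserves argument order, so row i still matches chunks[i].
    return await asyncio.gather(*(guarded(c) for c in chunks))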
@@ -1,39 +1,22 @@
-# services/retriever.py
+from __future__ import annotations
 import logging
 import numpy as np  # Make sure numpy is imported
 from typing import List, Dict, Any
 from sqlalchemy.orm import Session
 from pkg.rag.knowledge.services.base_service import BaseService
 from pkg.rag.knowledge.services.database import Chunk, SessionLocal
-from pkg.rag.knowledge.services.embedding_models import BaseEmbeddingModel, EmbeddingModelFactory
 from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
+from ....core import app

 logger = logging.getLogger(__name__)


 class Retriever(BaseService):
-    def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
+    def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager):
         super().__init__()
         self.logger = logging.getLogger(self.__class__.__name__)
-        self.model_type = model_type
-        self.model_name_key = model_name_key
         self.chroma_manager = chroma_manager
-
-        self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
-
-    def _load_embedding_model(self) -> BaseEmbeddingModel:
-        self.logger.info(
-            f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...'
-        )
-        try:
-            model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
-            self.logger.info(
-                f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}"
-            )
-            return model
-        except Exception as e:
-            self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
-            raise
+        self.ap = ap

     async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
         if not self.embedding_model:
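The diff is cut off here, and the visible tail of `retrieve` still guards on `self.embedding_model`, an attribute the new constructor no longer sets; presumably the method changes in the same direction as `Embedder.embed_and_store`. A speculative sketch of that shape; every name beyond the diff (the extra `embedding_model` parameter, `chroma_manager.search`) is an assumption:

from typing import Any, Dict, List


class Retriever:
    """Sketch only: mirrors the Embedder change, passing the runtime
    embedding model per call instead of loading one in the constructor."""

    def __init__(self, ap: Any, chroma_manager: Any) -> None:
        self.ap = ap
        self.chroma_manager = chroma_manager

    async def retrieve(self, query: str, embedding_model: Any, k: int = 5) -> List[Dict[str, Any]]:
        # Embed the query through the same requester path the Embedder uses.
        query_vec: List[float] = await embedding_model.requester.invoke_embedding(
            model=embedding_model,
            input_text=query,
        )
        # Nearest-neighbour lookup in Chroma; 'search' is a hypothetical
        # ChromaIndexManager method, not shown in this diff.
        return self.chroma_manager.search(query_vec, k=k)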