feat(rag): make embedding and retrieving available

This commit is contained in:
Junyan Qin
2025-07-16 21:17:18 +08:00
parent f731115805
commit 2f2db4d445
20 changed files with 180 additions and 368 deletions

View File

@@ -86,10 +86,10 @@ class RouterGroup(abc.ABC):
try:
return await f(*args, **kwargs)
except Exception: # 自动 500
except Exception as e: # 自动 500
traceback.print_exc()
# return self.http_status(500, -2, str(e))
return self.http_status(500, -2, 'internal server error')
return self.http_status(500, -2, str(e))
new_f = handler_error
new_f.__name__ = (self.name + rule).replace('/', '__')

View File

@@ -34,8 +34,9 @@ class FilesRouterGroup(group.RouterGroup):
file_bytes = await asyncio.to_thread(file.stream.read)
extension = file.filename.split('.')[-1]
file_name = file.filename.split('.')[0]
file_key = str(uuid.uuid4()) + '.' + extension
file_key = file_name + '_' + str(uuid.uuid4())[:8] + '.' + extension
# save file to storage
await self.ap.storage_mgr.storage_provider.save(file_key, file_bytes)
return self.success(

View File

@@ -72,3 +72,13 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str:
await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
return self.success({})
@self.route(
'/<knowledge_base_uuid>/retrieve',
methods=['POST'],
)
async def retrieve_knowledge_base(knowledge_base_uuid: str) -> str:
json_data = await quart.request.json
query = json_data.get('query')
results = await self.ap.knowledge_service.retrieve_knowledge_base(knowledge_base_uuid, query)
return self.success(data={'results': results})

View File

@@ -67,6 +67,13 @@ class KnowledgeService:
raise Exception('Knowledge base not found')
return await runtime_kb.store_file(file_id)
async def retrieve_knowledge_base(self, kb_uuid: str, query: str) -> list[dict]:
"""检索知识库"""
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
if runtime_kb is None:
raise Exception('Knowledge base not found')
return [result.model_dump() for result in await runtime_kb.retrieve(query)]
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
"""获取知识库文件"""
result = await self.ap.persistence_mgr.execute_async(

View File

@@ -186,6 +186,6 @@ class EmbeddingModelsService:
await runtime_embedding_model.requester.invoke_embedding(
model=runtime_embedding_model,
input_text='Hello, world!',
input_text=['Hello, world!'],
extra_args={},
)

View File

@@ -28,7 +28,7 @@ from ..storage import mgr as storagemgr
from ..utils import logcache
from . import taskmgr
from . import entities as core_entities
from ..rag.knowledge import mgr as rag_mgr
from ..rag.knowledge import kbmgr as rag_mgr
from ..vector import mgr as vectordb_mgr

View File

@@ -9,7 +9,7 @@ from ...command import cmdmgr
from ...provider.session import sessionmgr as llm_session_mgr
from ...provider.modelmgr import modelmgr as llm_model_mgr
from ...provider.tools import toolmgr as llm_tool_mgr
from ...rag.knowledge import mgr as rag_mgr
from ...rag.knowledge import kbmgr as rag_mgr
from ...platform import botmgr as im_mgr
from ...persistence import mgr as persistencemgr
from ...api.http.controller import main as http_controller
@@ -92,7 +92,7 @@ class BuildAppStage(stage.BootingStage):
ap.pipeline_mgr = pipeline_mgr
rag_mgr_inst = rag_mgr.RAGManager(ap)
await rag_mgr_inst.initialize_rag_system()
await rag_mgr_inst.initialize()
ap.rag_mgr = rag_mgr_inst
# 初始化向量数据库管理器

View File

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
import pydantic
from typing import Any
class RetrieveResultEntry(pydantic.BaseModel):
id: str
metadata: dict[str, Any]
distance: float

View File

@@ -101,18 +101,18 @@ class ProviderAPIRequester(metaclass=abc.ABCMeta):
async def invoke_embedding(
self,
model: RuntimeEmbeddingModel,
input_text: str,
input_text: list[str],
extra_args: dict[str, typing.Any] = {},
) -> list[float]:
) -> list[list[float]]:
"""调用 Embedding API
Args:
query (core_entities.Query): 请求上下文
model (RuntimeEmbeddingModel): 使用的模型信息
input_text (str): 输入文本
input_text (list[str]): 输入文本
extra_args (dict[str, typing.Any], optional): 额外的参数. Defaults to {}.
Returns:
list[float]: 返回的 embedding 向量
list[list[float]]: 返回的 embedding 向量
"""
pass

View File

@@ -145,9 +145,9 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
async def invoke_embedding(
self,
model: requester.RuntimeEmbeddingModel,
input_text: str,
input_text: list[str],
extra_args: dict[str, typing.Any] = {},
) -> list[float]:
) -> list[list[float]]:
"""调用 Embedding API"""
self.client.api_key = model.token_mgr.get_token()
@@ -163,7 +163,8 @@ class OpenAIChatCompletions(requester.ProviderAPIRequester):
try:
resp = await self.client.embeddings.create(**args)
return resp.data[0].embedding
return [d.embedding for d in resp.data]
except asyncio.TimeoutError:
raise errors.RequesterError('请求超时')
except openai.BadRequestError as e:

View File

@@ -1,21 +1,15 @@
from __future__ import annotations
import os
import asyncio
import traceback
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.database import (
KnowledgeBase,
File,
Chunk,
)
from .services import parser, chunker
from pkg.core import app
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
import sqlalchemy
from ...entity.persistence import rag as persistence_rag
from pkg.core import taskmgr
from ...entity.rag import retriever as retriever_entities
class RuntimeKnowledgeBase:
@@ -23,9 +17,9 @@ class RuntimeKnowledgeBase:
knowledge_base_entity: persistence_rag.KnowledgeBase
parser: FileParser
parser: parser.FileParser
chunker: Chunker
chunker: chunker.Chunker
embedder: Embedder
@@ -34,8 +28,8 @@ class RuntimeKnowledgeBase:
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
self.ap = ap
self.knowledge_base_entity = knowledge_base_entity
self.parser = FileParser(ap=self.ap)
self.chunker = Chunker(ap=self.ap)
self.parser = parser.FileParser(ap=self.ap)
self.chunker = chunker.Chunker(ap=self.ap)
self.embedder = Embedder(ap=self.ap)
self.retriever = Retriever(ap=self.ap)
# 传递kb_id给retriever
@@ -66,9 +60,16 @@ class RuntimeKnowledgeBase:
raise Exception(f'No chunks extracted from file {file.file_name}')
task_context.set_current_action('Embedding chunks')
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
self.knowledge_base_entity.embedding_model_uuid
)
# embed chunks
await self.embedder.embed_and_store(
file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid
kb_id=self.knowledge_base_entity.uuid,
file_id=file.uuid,
chunks=chunks_texts,
embedding_model=embedding_model,
)
# set file status to completed
@@ -79,7 +80,8 @@ class RuntimeKnowledgeBase:
)
except Exception as e:
self.ap.logger.error(f'Error storing file {file.file_id}: {e}')
self.ap.logger.error(f'Error storing file {file.uuid}: {e}')
traceback.print_exc()
# set file status to failed
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
@@ -89,7 +91,7 @@ class RuntimeKnowledgeBase:
raise
async def store_file(self, file_id: str) -> int:
async def store_file(self, file_id: str) -> str:
# pre checking
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
raise Exception(f'File {file_id} not found')
@@ -97,22 +99,24 @@ class RuntimeKnowledgeBase:
file_uuid = str(uuid.uuid4())
kb_id = self.knowledge_base_entity.uuid
file_name = file_id
extension = os.path.splitext(file_id)[1].lstrip('.')
extension = file_name.split('.')[-1]
file = persistence_rag.File(
uuid=file_uuid,
kb_id=kb_id,
file_name=file_name,
extension=extension,
status='pending',
)
file_obj_data = {
'uuid': file_uuid,
'kb_id': kb_id,
'file_name': file_name,
'extension': extension,
'status': 'pending',
}
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict()))
file_obj = persistence_rag.File(**file_obj_data)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(file_obj_data))
# run background task asynchronously
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._store_file_task(file, task_context=ctx),
self._store_file_task(file_obj, task_context=ctx),
kind='knowledge-operation',
name=f'knowledge-store-file-{file_id}',
label=f'Store file {file_id}',
@@ -120,6 +124,12 @@ class RuntimeKnowledgeBase:
)
return wrapper.id
async def retrieve(self, query: str) -> list[retriever_entities.RetrieveResultEntry]:
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(
self.knowledge_base_entity.embedding_model_uuid
)
return await self.retriever.retrieve(self.knowledge_base_entity.uuid, query, embedding_model)
async def dispose(self):
pass
@@ -134,7 +144,7 @@ class RAGManager:
self.knowledge_bases = []
async def initialize(self):
pass
await self.load_knowledge_bases_from_db()
async def load_knowledge_bases_from_db(self):
self.ap.logger.info('Loading knowledge bases from db...')
@@ -183,57 +193,6 @@ class RAGManager:
self.knowledge_bases.remove(kb)
return
async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None):
"""
Parses, chunks, embeds, and stores data from a given file into the RAG system.
Associates the file with a knowledge base using kb_id in the File table.
"""
self.ap.logger.info(f'Starting data storage process for file: {file_path}')
session = SessionLocal()
file_obj = None
try:
kb = session.query(KnowledgeBase).filter_by(id=kb_id).first()
if not kb:
self.ap.logger.info(f'Knowledge Base "{kb_id}" does not exist. ')
return
# get embedding model
embedding_model = await self.ap.model_mgr.get_embedding_model_by_uuid(kb.embedding_model_uuid)
file_name = os.path.basename(file_path)
text = await self.parser.parse(file_path)
if not text:
self.ap.logger.warning(f'No text extracted from file {file_path}. ')
return
chunks_texts = await self.chunker.chunk(text)
self.ap.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
await self.embedder.embed_and_store(file_id=file_id, chunks=chunks_texts, embedding_model=embedding_model)
self.ap.logger.info(f'Data storage process completed for file: {file_path}')
except Exception as e:
session.rollback()
self.ap.logger.error(f'Error in store_data for file {file_path}: {str(e)}', exc_info=True)
raise
finally:
if file_id:
file_obj = session.query(File).filter_by(id=file_id).first()
if file_obj:
file_obj.status = 1
session.close()
async def retrieve_data(self, query: str):
"""
Retrieves relevant data chunks based on a given query using the configured retriever.
"""
self.ap.logger.info(f"Starting data retrieval process for query: '{query}'")
try:
retrieved_chunks = await self.retriever.retrieve(query)
self.ap.logger.info(f'Successfully retrieved {len(retrieved_chunks)} chunks for query.')
return retrieved_chunks
except Exception as e:
self.ap.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
return []
async def delete_data_by_file_id(self, file_id: str):
"""
Deletes all data associated with a specific file ID, including its chunks and vectors,
@@ -243,7 +202,7 @@ class RAGManager:
session = SessionLocal()
try:
# delete vectors
await asyncio.to_thread(self.chroma_manager.delete_by_file_id_sync, file_id)
await asyncio.to_thread(self.ap.vector_db_mgr.vector_db.delete_by_file_id_sync, file_id)
self.ap.logger.info(f'Deleted embeddings from ChromaDB for file_id: {file_id}')
chunks_to_delete = session.query(Chunk).filter_by(file_id=file_id).all()
@@ -333,74 +292,18 @@ class RAGManager:
if session.is_active:
session.close()
async def get_file_content_by_file_id(self, file_id: str) -> str:
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
# async def get_file_content_by_file_id(self, file_id: str) -> str:
# file_bytes = await self.ap.storage_mgr.storage_provider.load(file_id)
_, ext = os.path.splitext(file_id.lower())
ext = ext.lstrip('.')
# _, ext = os.path.splitext(file_id.lower())
# ext = ext.lstrip('.')
try:
text = file_bytes.decode('utf-8')
except UnicodeDecodeError:
return '[非文本文件或编码无法识别]'
# try:
# text = file_bytes.decode('utf-8')
# except UnicodeDecodeError:
# return '[非文本文件或编码无法识别]'
if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
return text
else:
return f'[未知类型: .{ext}]'
async def relate_file_id_with_kb(self, knowledge_base_uuid: str, file_id: str) -> None:
"""
Associates a file with a knowledge base by updating the kb_id in the File table.
"""
self.ap.logger.info(f'Associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}')
session = SessionLocal()
try:
# 查询知识库是否存在
kb = session.query(KnowledgeBase).filter_by(id=knowledge_base_uuid).first()
if not kb:
self.ap.logger.error(f'Knowledge Base with UUID {knowledge_base_uuid} not found.')
return
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
self.ap.logger.error(f'File with ID {file_id} does not exist.')
return
self.ap.logger.info(f'File with ID {file_id} exists, proceeding with association.')
# add new file record
file_to_update = File(
id=file_id,
kb_id=kb.id,
file_name=file_id,
path=os.path.join('data', 'storage', file_id),
file_type=os.path.splitext(file_id)[1].lstrip('.'),
status=0,
)
session.add(file_to_update)
session.commit()
self.ap.logger.info(
f'Successfully associated file ID {file_id} with knowledge base UUID {knowledge_base_uuid}'
)
except Exception as e:
session.rollback()
self.ap.logger.error(
f'Error associating file ID {file_id} with knowledge base UUID {knowledge_base_uuid}: {str(e)}',
exc_info=True,
)
finally:
# 进行文件解析
try:
await self.store_data(
file_path=os.path.join('data', 'storage', file_id),
kb_id=knowledge_base_uuid,
file_type=os.path.splitext(file_id)[1].lstrip('.'),
file_id=file_id,
)
except Exception:
# 如果存储数据时出错,更新文件状态为失败
file_obj = session.query(File).filter_by(id=file_id).first()
if file_obj:
file_obj.status = 2
session.commit()
self.ap.logger.error(f'Error storing data for file ID {file_id}', exc_info=True)
session.close()
# if ext in ['txt', 'md', 'csv', 'log', 'py', 'html']:
# return text
# else:
# return f'[未知类型: .{ext}]'

View File

@@ -1,26 +1,15 @@
# 封装异步操作
import asyncio
import logging
from pkg.rag.knowledge.services.database import SessionLocal
class BaseService:
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
self.db_session_factory = SessionLocal
pass
async def _run_sync(self, func, *args, **kwargs):
"""
在单独的线程中运行同步函数。
如果第一个参数是 session则在 to_thread 中获取新的 session。
"""
if getattr(func, '__name__', '').startswith('_db_'):
session = await asyncio.to_thread(self.db_session_factory)
try:
result = await asyncio.to_thread(func, session, *args, **kwargs)
return result
finally:
session.close()
else:
# 否则,直接运行同步函数
return await asyncio.to_thread(func, *args, **kwargs)
return await asyncio.to_thread(func, *args, **kwargs)

View File

@@ -1,24 +1,21 @@
# services/chunker.py
import logging
from __future__ import annotations
from typing import List
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.rag.knowledge.services import base_service
from pkg.core import app
logger = logging.getLogger(__name__)
class Chunker(BaseService):
class Chunker(base_service.BaseService):
"""
A class for splitting long texts into smaller, overlapping chunks.
"""
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__(ap) # Initialize BaseService
self.ap = ap
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
if self.chunk_overlap >= self.chunk_size:
self.logger.warning(
self.ap.logger.warning(
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
)

View File

@@ -1,23 +0,0 @@
# 全部迁移过去
from pkg.entity.persistence.rag import (
create_db_and_tables,
SessionLocal,
Base,
engine,
KnowledgeBase,
File,
Chunk,
Vector,
)
__all__ = [
"create_db_and_tables",
"SessionLocal",
"Base",
"engine",
"KnowledgeBase",
"File",
"Chunk",
"Vector",
]

View File

@@ -1,12 +1,11 @@
from __future__ import annotations
import asyncio
import numpy as np
import uuid
from typing import List
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from ....entity.persistence import rag as persistence_rag
from ....core import app
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
import sqlalchemy
class Embedder(BaseService):
@@ -14,74 +13,41 @@ class Embedder(BaseService):
super().__init__()
self.ap = ap
def _db_save_chunks_sync(self, session: Session, file_id: int, chunks_texts: List[str]):
"""
Saves chunks to the relational database and returns the created Chunk objects.
This function assumes it's called within a context where the session
will be committed/rolled back and closed by the caller.
"""
self.ap.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).')
chunk_objects = []
for text in chunks_texts:
chunk = Chunk(file_id=file_id, text=text)
session.add(chunk)
chunk_objects.append(chunk)
session.flush() # This populates the .id attribute for each new chunk object
self.ap.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.')
return chunk_objects
async def embed_and_store(
self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> List[Chunk]:
session = SessionLocal() # Start a session that will live for the whole operation
chunk_objects = []
try:
# 1. Save chunks to the relational database first to get their IDs
# We call _db_save_chunks_sync directly without _run_sync's session management
# because we manage the session here across multiple async calls.
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
session.commit() # Commit chunks to make their IDs permanent and accessible
self, kb_id: str, file_id: str, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> list[persistence_rag.Chunk]:
# save chunk to db
chunk_entities: list[persistence_rag.Chunk] = []
chunk_ids: list[str] = []
if not chunk_objects:
self.ap.logger.warning(
f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.'
)
return []
for chunk_text in chunks:
chunk_uuid = str(uuid.uuid4())
chunk_ids.append(chunk_uuid)
chunk_entity = persistence_rag.Chunk(uuid=chunk_uuid, file_id=file_id, text=chunk_text)
chunk_entities.append(chunk_entity)
# get the embeddings for the chunks
embeddings: list[list[float]] = []
chunk_dicts = [
self.ap.persistence_mgr.serialize_model(persistence_rag.Chunk, chunk) for chunk in chunk_entities
]
for chunk in chunks:
result = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunk,
)
embeddings.append(result)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.Chunk).values(chunk_dicts))
embeddings_np = np.array(embeddings, dtype=np.float32)
# get embeddings
embeddings_list: list[list[float]] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunks,
extra_args={}, # TODO: add extra args
)
chunk_ids = [c.id for c in chunk_objects]
# collection名用kb_idfile对象有kb_id字段
kb_id = session.query(Chunk).filter_by(id=chunk_ids[0]).first().file.kb_id if chunk_ids else None
if not kb_id:
self.ap.logger.warning('无法获取kb_id向量存储失败')
return chunk_objects
chroma_ids = [f'{file_id}_{cid}' for cid in chunk_ids]
metadatas = [{'file_id': file_id, 'chunk_id': cid} for cid in chunk_ids]
await self._run_sync(
self.ap.vector_db_mgr.vector_db.add_embeddings,
kb_id,
chroma_ids,
embeddings_np,
metadatas,
chunks,
)
self.ap.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to VectorDB.')
return chunk_objects
# save embeddings to vdb
await self._run_sync(
self.ap.vector_db_mgr.vector_db.add_embeddings,
kb_id,
chunk_ids,
embeddings_list,
chunk_dicts,
)
except Exception as e:
session.rollback() # Rollback on any error
self.ap.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True)
raise # Re-raise the exception to propagate it
finally:
session.close() # Ensure the session is always closed
self.ap.logger.info(f'Successfully saved {len(chunk_entities)} embeddings to Knowledge Base.')
return chunk_entities

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import PyPDF2
import io
from docx import Document

View File

@@ -1,99 +1,46 @@
from __future__ import annotations
import logging
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.vector.vdb import VectorDatabase
from . import base_service
from ....core import app
logger = logging.getLogger(__name__)
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
from ....entity.rag import retriever as retriever_entities
class Retriever(BaseService):
class Retriever(base_service.BaseService):
def __init__(self, ap: app.Application):
super().__init__()
self.logger = logging.getLogger(self.__class__.__name__)
self.ap = ap
self.vector_db: VectorDatabase = ap.vector_db_mgr.vector_db
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
if not self.embedding_model:
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
async def retrieve(
self, kb_id: str, query: str, embedding_model: RuntimeEmbeddingModel, k: int = 5
) -> list[retriever_entities.RetrieveResultEntry]:
self.ap.logger.info(f"Retrieving for query: '{query}' with k={k} using {embedding_model.model_entity.uuid}")
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
query_embedding: list[float] = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=[query],
extra_args={}, # TODO: add extra args
)
query_embedding: List[float] = await self.embedding_model.embed_query(query)
query_embedding_np = np.array([query_embedding], dtype=np.float32)
# collection名用kb_id假设retriever有kb_id属性或通过ap传递
kb_id = getattr(self, 'kb_id', None)
if not kb_id:
self.logger.warning('无法获取kb_id向量检索失败')
return []
chroma_results = await self._run_sync(self.vector_db.search, kb_id, query_embedding_np, k)
chroma_results = await self._run_sync(self.ap.vector_db_mgr.vector_db.search, kb_id, query_embedding[0], k)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
distances = chroma_results.get('distances', [[]])[0]
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
chroma_documents = chroma_results.get('documents', [[]])[0]
if not matched_chroma_ids:
self.logger.info('No relevant chunks found in Chroma.')
self.ap.logger.info('No relevant chunks found in Chroma.')
return []
db_chunk_ids = []
for metadata in chroma_metadatas:
if 'chunk_id' in metadata:
db_chunk_ids.append(metadata['chunk_id'])
else:
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
result: list[retriever_entities.RetrieveResultEntry] = []
if not db_chunk_ids:
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
return []
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
chunks_from_db = await self._run_sync(
lambda cids: self._db_get_chunks_sync(
SessionLocal(), cids
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids,
)
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
results_list: List[Dict[str, Any]] = []
for i, chroma_id in enumerate(matched_chroma_ids):
try:
# Ensure original_chunk_id is int for DB lookup
original_chunk_id = int(chroma_id.split('_')[-1])
except (ValueError, IndexError):
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
continue
chunk_text_from_chroma = chroma_documents[i]
distance = float(distances[i])
file_id_from_chroma = chroma_metadatas[i].get('file_id')
chunk_from_db = chunk_map.get(original_chunk_id)
results_list.append(
{
'chunk_id': original_chunk_id,
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
'distance': distance,
'file_id': file_id_from_chroma,
}
for i, id in enumerate(matched_chroma_ids):
entry = retriever_entities.RetrieveResultEntry(
id=id,
metadata=chroma_metadatas[i],
distance=distances[i],
)
result.append(entry)
self.logger.info(f'Retrieved {len(results_list)} chunks.')
return results_list
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
session.close()
return chunks
return result

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import abc
from typing import Any, List, Dict
from typing import Any, Dict
import numpy as np
@@ -9,10 +9,10 @@ class VectorDatabase(abc.ABC):
def add_embeddings(
self,
collection: str,
ids: List[str],
embeddings: np.ndarray,
metadatas: List[Dict[str, Any]],
documents: List[str],
ids: list[str],
embeddings_list: list[list[float]],
metadatas: list[dict[str, Any]],
documents: list[str],
) -> None:
"""向指定 collection 添加向量数据。"""
pass

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import numpy as np
from typing import Any, List, Dict
import chromadb
from typing import Any
from chromadb import PersistentClient
from pkg.vector.vdb import VectorDatabase
from pkg.core import app
@@ -12,7 +12,7 @@ class ChromaVectorDatabase(VectorDatabase):
self.client = PersistentClient(path=base_path)
self._collections = {}
def get_or_create_collection(self, collection: str):
def get_or_create_collection(self, collection: str) -> chromadb.Collection:
if collection not in self._collections:
self._collections[collection] = self.client.get_or_create_collection(name=collection)
self.ap.logger.info(f"Chroma collection '{collection}' accessed/created.")
@@ -21,26 +21,25 @@ class ChromaVectorDatabase(VectorDatabase):
def add_embeddings(
self,
collection: str,
ids: List[str],
embeddings: np.ndarray,
metadatas: List[Dict[str, Any]],
documents: List[str],
ids: list[str],
embeddings_list: list[list[float]],
metadatas: list[dict[str, Any]],
) -> None:
col = self.get_or_create_collection(collection)
col.add(embeddings=embeddings.tolist(), ids=ids, metadatas=metadatas, documents=documents)
col.add(embeddings=embeddings_list, ids=ids, metadatas=metadatas)
self.ap.logger.info(f"Added {len(ids)} embeddings to Chroma collection '{collection}'.")
def search(self, collection: str, query_embedding: np.ndarray, k: int = 5) -> Dict[str, Any]:
def search(self, collection: str, query_embedding: list[float], k: int = 5) -> dict[str, Any]:
col = self.get_or_create_collection(collection)
results = col.query(
query_embeddings=query_embedding.tolist(),
query_embeddings=query_embedding,
n_results=k,
include=['metadatas', 'distances', 'documents'],
)
self.ap.logger.debug(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
self.ap.logger.info(f"Chroma search in '{collection}' returned {len(results.get('ids', [[]])[0])} results.")
return results
def delete_by_metadata(self, collection: str, where: Dict[str, Any]) -> None:
def delete_by_metadata(self, collection: str, where: dict[str, Any]) -> None:
col = self.get_or_create_collection(collection)
col.delete(where=where)
self.ap.logger.info(f"Deleted embeddings from Chroma collection '{collection}' with filter: {where}")