chore: stash

Junyan Qin
2025-07-15 22:09:10 +08:00
parent 199164fc4b
commit 67bc065ccd
15 changed files with 508 additions and 338 deletions

View File

@@ -14,7 +14,7 @@ preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""
def group_class(name: str, path: str) -> None:
def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]:
"""注册一个 RouterGroup"""
def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
@@ -120,6 +120,6 @@ class RouterGroup(abc.ABC):
}
)
def http_status(self, status: int, code: int, msg: str) -> quart.Response:
def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status
return (self.fail(code, msg), status)
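For context, a minimal sketch of how the reworked group_class decorator and the tuple-returning http_status are presumably consumed; the stand-in class below is simplified for illustration and is not the project's actual RouterGroup base.
import typing

preregistered_groups: list[type] = []

def group_class(name: str, path: str) -> typing.Callable[[type], type]:
    # Register a RouterGroup subclass; returning the class keeps the decorator transparent.
    def decorator(cls: type) -> type:
        cls.name = name
        cls.path = path
        preregistered_groups.append(cls)
        return cls
    return decorator

@group_class('example', '/api/v1/example')
class ExampleGroup:
    pass

assert preregistered_groups == [ExampleGroup]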

View File

@@ -5,77 +5,49 @@ from ... import group
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases')
async def handle_knowledge_bases() -> str:
@self.route('', methods=['POST', 'GET'])
async def handle_knowledge_bases() -> quart.Response:
if quart.request.method == 'GET':
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
bases_list = [
{
'uuid': kb.id,
'name': kb.name,
'description': kb.description,
'embedding_model_uuid': kb.embedding_model_uuid,
'top_k': kb.top_k,
}
for kb in knowledge_bases
]
return self.success(data={'bases': bases_list})
knowledge_bases = await self.ap.knowledge_service.get_knowledge_bases()
return self.success(data={'bases': knowledge_bases})
elif quart.request.method == 'POST':
json_data = await quart.request.json
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'),
json_data.get('description'),
json_data.get('embedding_model_uuid'),
json_data.get('top_k',5),
)
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
return self.success(data={'uuid': knowledge_base_uuid})
return self.http_status(405, -1, 'Method not allowed')
@self.route(
'/<knowledge_base_uuid>',
methods=['GET', 'DELETE'],
endpoint='handle_specific_knowledge_base',
)
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str:
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response:
if quart.request.method == 'GET':
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
knowledge_base = await self.ap.knowledge_service.get_knowledge_base(knowledge_base_uuid)
if knowledge_base is None:
return self.http_status(404, -1, 'knowledge base not found')
return self.success(
data={
'base': {
'name': knowledge_base.name,
'description': knowledge_base.description,
'uuid': knowledge_base.id,
'embedding_model_uuid': knowledge_base.embedding_model_uuid,
'top_k': knowledge_base.top_k,
},
'base': knowledge_base,
}
)
elif quart.request.method == 'DELETE':
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
return self.success({})
@self.route(
'/<knowledge_base_uuid>/files',
methods=['GET', 'POST'],
endpoint='get_knowledge_base_files',
)
async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
files = await self.ap.knowledge_service.get_files_by_knowledge_base(knowledge_base_uuid)
return self.success(
data={
'files': [
{
'id': file.id,
'file_name': file.file_name,
'status': file.status,
}
for file in files
],
'files': files,
}
)
@@ -86,14 +58,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
return self.http_status(400, -1, 'File ID is required')
# Call the service-layer method to associate the file with the knowledge base
await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id)
return self.success({})
task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id)
return self.success(
{
'task_id': task_id,
}
)
@self.route(
'/<knowledge_base_uuid>/files/<file_id>',
methods=['DELETE'],
endpoint='delete_specific_file_in_kb',
)
async def delete_specific_file_in_kb(file_id: str,knowledge_base_uuid: str) -> str:
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> quart.Response:
await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
return self.success({})
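A rough sketch of how a client might exercise the reworked endpoints; the base URL and the exact response envelope are assumptions for illustration, not defined by this diff.
import requests

BASE = 'http://localhost:5300/api/v1/knowledge/bases'  # host/port are assumptions

# Create a knowledge base; the handler now forwards the JSON body to KnowledgeService as-is.
resp = requests.post(BASE, json={
    'name': 'docs',
    'description': 'Product documentation',
    'embedding_model_uuid': 'some-embedding-model-uuid',
    'top_k': 5,
})
kb_uuid = resp.json()['data']['uuid']  # assumes a {'code': ..., 'data': ...} envelope

# Attach an already-uploaded file; the response now carries a background task id.
resp = requests.post(f'{BASE}/{kb_uuid}/files', json={'file_id': 'manual.pdf'})
task_id = resp.json()['data']['task_id']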

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import uuid
import sqlalchemy
from ....core import app
from ....entity.persistence import rag as persistence_rag
class KnowledgeService:
"""知识库服务"""
ap: app.Application
def __init__(self, ap: app.Application) -> None:
self.ap = ap
async def get_knowledge_bases(self) -> list[dict]:
"""获取所有知识库"""
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
knowledge_bases = result.all()
return [
self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
for knowledge_base in knowledge_bases
]
async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
"""获取知识库"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
)
knowledge_base = result.first()
if knowledge_base is None:
return None
return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
async def create_knowledge_base(self, kb_data: dict) -> str:
"""创建知识库"""
kb_data['uuid'] = str(uuid.uuid4())
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))
kb = await self.get_knowledge_base(kb_data['uuid'])
await self.ap.rag_mgr.load_knowledge_base(kb)
return kb_data['uuid']
async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
"""更新知识库"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.KnowledgeBase)
.values(kb_data)
.where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
)
await self.ap.rag_mgr.remove_knowledge_base(kb_uuid)
kb = await self.get_knowledge_base(kb_uuid)
await self.ap.rag_mgr.load_knowledge_base(kb)
async def store_file(self, kb_uuid: str, file_id: str) -> int:
"""存储文件"""
# await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id))
# await self.ap.rag_mgr.store_file(file_id)
runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
if runtime_kb is None:
raise Exception('Knowledge base not found')
return await runtime_kb.store_file(file_id)
async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
"""获取知识库文件"""
result = await self.ap.persistence_mgr.execute_async(
sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
)
files = result.all()
return [self.ap.persistence_mgr.serialize_model(persistence_rag.File, file) for file in files]
async def delete_file(self, kb_uuid: str, file_id: str) -> None:
"""删除文件"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
)
# TODO: remove from memory
async def delete_knowledge_base(self, kb_uuid: str) -> None:
"""删除知识库"""
await self.ap.persistence_mgr.execute_async(
sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
)
# TODO: remove from memory
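A short sketch of how the new KnowledgeService is presumably driven from other components, assuming ap is a fully initialized Application; the field names mirror the kb_data dict handled above.
async def example_usage(ap):
    kb_uuid = await ap.knowledge_service.create_knowledge_base({
        'name': 'docs',
        'description': 'Product documentation',
        'embedding_model_uuid': 'some-embedding-model-uuid',
        'top_k': 5,
    })
    # store_file() returns the id of the background ingestion task.
    task_id = await ap.knowledge_service.store_file(kb_uuid, 'manual.pdf')
    files = await ap.knowledge_service.get_files_by_knowledge_base(kb_uuid)
    return task_id, files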

View File

@@ -22,6 +22,7 @@ from ..api.http.service import user as user_service
from ..api.http.service import model as model_service
from ..api.http.service import pipeline as pipeline_service
from ..api.http.service import bot as bot_service
from ..api.http.service import knowledge as knowledge_service
from ..discover import engine as discover_engine
from ..storage import mgr as storagemgr
from ..utils import logcache
@@ -48,6 +49,8 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None
rag_mgr: rag_mgr.RAGManager = None
# TODO: move into pipeline
tool_mgr: llm_tool_mgr.ToolManager = None
@@ -112,7 +115,7 @@ class Application:
bot_service: bot_service.BotService = None
knowledge_base_service: rag_mgr.RAGManager = None
knowledge_service: knowledge_service.KnowledgeService = None
def __init__(self):
pass

View File

@@ -17,6 +17,7 @@ from ...api.http.service import user as user_service
from ...api.http.service import model as model_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
@@ -89,6 +90,10 @@ class BuildAppStage(stage.BootingStage):
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr
rag_mgr_inst = rag_mgr.RAGManager(ap)
await rag_mgr_inst.initialize_rag_system()
ap.rag_mgr = rag_mgr_inst
http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
@@ -102,15 +107,14 @@ class BuildAppStage(stage.BootingStage):
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
ap.embedding_models_service = embedding_models_service_inst
knowledge_base_service_inst = rag_mgr.RAGManager(ap)
await knowledge_base_service_inst.initialize_rag_system()
ap.knowledge_base_service = knowledge_base_service_inst
pipeline_service_inst = pipeline_service.PipelineService(ap)
ap.pipeline_service = pipeline_service_inst
bot_service_inst = bot_service.BotService(ap)
ap.bot_service = bot_service_inst
knowledge_service_inst = knowledge_service.KnowledgeService(ap)
ap.knowledge_service = knowledge_service_inst
ctrl = controller.Controller(ap)
ap.ctrl = ctrl

View File

@@ -1,51 +1,50 @@
from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary, Integer
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime
import os
import uuid
import sqlalchemy
from .base import Base
Base = declarative_base()
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
print("Using database URL:", DATABASE_URL)
# Base = declarative_base()
# DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
# print("Using database URL:", DATABASE_URL)
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})
# engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def create_db_and_tables():
"""Creates all database tables defined in the Base."""
Base.metadata.create_all(bind=engine)
print('Database tables created or already exist.')
# def create_db_and_tables():
# """Creates all database tables defined in the Base."""
# Base.metadata.create_all(bind=engine)
# print('Database tables created or already exist.')
class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(String, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model_uuid = Column(String, default='')
top_k = Column(Integer, default=5)
__tablename__ = 'knowledge_bases'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String, index=True)
description = sqlalchemy.Column(sqlalchemy.Text)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5)
class File(Base):
__tablename__ = 'file'
id = Column(String, primary_key=True, index=True)
kb_id = Column(String, nullable=True)
file_name = Column(String)
path = Column(String)
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0) # 0: uploaded and processing, 1: completed, 2: failed
__tablename__ = 'knowledge_base_files'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
kb_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
file_name = sqlalchemy.Column(sqlalchemy.String)
extension = sqlalchemy.Column(sqlalchemy.String)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
status = sqlalchemy.Column(sqlalchemy.String, default='pending') # pending, processing, completed, failed
class Chunk(Base):
__tablename__ = 'chunks'
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
file_id = Column(String, nullable=True)
text = Column(Text)
__tablename__ = 'knowledge_base_chunks'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
text = sqlalchemy.Column(sqlalchemy.Text)
class Vector(Base):
__tablename__ = 'vectors'
id = Column(String, primary_key=True, index=True)
chunk_id = Column(String, nullable=True)
embedding = Column(LargeBinary)
# class Vector(Base):
# __tablename__ = 'knowledge_base_vectors'
# uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
# chunk_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
# embedding = sqlalchemy.Column(sqlalchemy.LargeBinary)
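For reference, a sketch of the kind of SQLAlchemy Core statements persistence_mgr.execute_async is fed for these tables; the absolute import path below is an assumption based on the relative imports used elsewhere in this diff.
import uuid
import sqlalchemy
from pkg.entity.persistence import rag as persistence_rag  # assumed import path

insert_kb = sqlalchemy.insert(persistence_rag.KnowledgeBase).values(
    uuid=str(uuid.uuid4()),
    name='docs',
    description='Product documentation',
    embedding_model_uuid='some-embedding-model-uuid',
    top_k=5,
)

completed_files = sqlalchemy.select(persistence_rag.File).where(
    persistence_rag.File.status == 'completed'
)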

View File

@@ -119,7 +119,7 @@ class EventLogger:
async def _truncate_logs(self):
if len(self.logs) > MAX_LOG_COUNT:
for i in range(DELETE_COUNT_PER_TIME):
for image_key in self.logs[i].images:
for image_key in self.logs[i].images: # type: ignore
await self.ap.storage_mgr.storage_provider.delete(image_key)
self.logs = self.logs[DELETE_COUNT_PER_TIME:]

View File

@@ -1,149 +1,189 @@
from __future__ import annotations
import os
import asyncio
import traceback
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
from pkg.rag.knowledge.services.database import (
KnowledgeBase,
File,
Chunk,
)
from pkg.core import app
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from ...entity.persistence import model as persistence_model
from pkg.core import taskmgr
from ...entity.persistence import rag as persistence_rag
import sqlalchemy
class RuntimeKnowledgeBase:
ap: app.Application
knowledge_base_entity: persistence_rag.KnowledgeBase
chroma_manager: ChromaIndexManager
parser: FileParser
chunker: Chunker
embedder: Embedder
retriever: Retriever
def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
self.ap = ap
self.knowledge_base_entity = knowledge_base_entity
self.chroma_manager = ChromaIndexManager(ap=self.ap)
self.parser = FileParser(ap=self.ap)
self.chunker = Chunker(ap=self.ap)
self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)
async def initialize(self):
pass
async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext):
try:
# set file status to processing
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
.where(persistence_rag.File.uuid == file.uuid)
.values(status='processing')
)
task_context.set_current_action('Parsing file')
# parse file
text = await self.parser.parse(file.file_name, file.extension)
if not text:
raise Exception(f'No text extracted from file {file.file_name}')
task_context.set_current_action('Chunking file')
# chunk file
chunks_texts = await self.chunker.chunk(text)
if not chunks_texts:
raise Exception(f'No chunks extracted from file {file.file_name}')
task_context.set_current_action('Embedding chunks')
# embed chunks
await self.embedder.embed_and_store(
file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid
)
# set file status to completed
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
.where(persistence_rag.File.uuid == file.uuid)
.values(status='completed')
)
except Exception as e:
self.ap.logger.error(f'Error storing file {file.file_name}: {e}')
# set file status to failed
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
.where(persistence_rag.File.uuid == file.uuid)
.values(status='failed')
)
raise
async def store_file(self, file_id: str) -> int:
# pre checking
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
raise Exception(f'File {file_id} not found')
file_uuid = str(uuid.uuid4())
kb_id = self.knowledge_base_entity.uuid
file_name = file_id
extension = os.path.splitext(file_id)[1].lstrip('.')
file = persistence_rag.File(
uuid=file_uuid,
kb_id=kb_id,
file_name=file_name,
extension=extension,
status='pending',
)
await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict()))
# run background task asynchronously
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._store_file_task(file, task_context=ctx),
kind='knowledge-operation',
name=f'knowledge-store-file-{file_id}',
label=f'Store file {file_id}',
context=ctx,
)
return wrapper.id
async def dispose(self):
pass
class RAGManager:
ap: app.Application
knowledge_bases: list[RuntimeKnowledgeBase]
def __init__(self, ap: app.Application):
self.ap = ap
self.chroma_manager = ChromaIndexManager()
self.parser = FileParser()
self.chunker = Chunker()
self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)
self.knowledge_bases = []
async def initialize_rag_system(self):
"""Initializes the RAG system by creating database tables."""
await asyncio.to_thread(create_db_and_tables)
async def initialize(self):
pass
async def create_knowledge_base(
self, kb_name: str, kb_description: str, embedding_model_uuid: str = '', top_k: int = 5
):
"""
Creates a new knowledge base if it doesn't already exist.
"""
try:
if not kb_name:
raise ValueError('Knowledge base name must be set while creating.')
async def load_knowledge_bases_from_db(self):
self.ap.logger.info('Loading knowledge bases from db...')
def _create_kb_sync():
session = SessionLocal()
try:
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
if not kb:
id = str(uuid.uuid4())
new_kb = KnowledgeBase(
name=kb_name,
description=kb_description,
embedding_model_uuid=embedding_model_uuid,
top_k=top_k,
id=id,
)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
self.ap.logger.info(f"Knowledge Base '{kb_name}' created.")
print(embedding_model_uuid)
return new_kb.id
else:
self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
except Exception as e:
session.rollback()
self.ap.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True)
raise
finally:
session.close()
self.knowledge_bases = []
return await asyncio.to_thread(_create_kb_sync)
except Exception as e:
self.ap.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
raise
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
async def get_all_knowledge_bases(self):
"""
Retrieves all knowledge bases from the database.
"""
try:
knowledge_bases = result.all()
def _get_all_kbs_sync():
session = SessionLocal()
try:
return session.query(KnowledgeBase).all()
finally:
session.close()
for knowledge_base in knowledge_bases:
try:
await self.load_knowledge_base(knowledge_base)
except Exception as e:
self.ap.logger.error(
f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
)
return await asyncio.to_thread(_get_all_kbs_sync)
except Exception as e:
self.ap.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True)
return []
async def load_knowledge_base(
self,
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
) -> RuntimeKnowledgeBase:
if isinstance(knowledge_base_entity, sqlalchemy.Row):
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
elif isinstance(knowledge_base_entity, dict):
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity)
async def get_knowledge_base_by_id(self, kb_id: str):
"""
Retrieves a specific knowledge base by its ID.
"""
try:
runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)
def _get_kb_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(id=kb_id_param).first()
finally:
session.close()
await runtime_knowledge_base.initialize()
return await asyncio.to_thread(_get_kb_sync, kb_id)
except Exception as e:
self.ap.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True)
return None
self.knowledge_bases.append(runtime_knowledge_base)
async def get_files_by_knowledge_base(self, kb_id: str):
"""
Retrieves files associated with a specific knowledge base by querying the File table directly.
"""
try:
return runtime_knowledge_base
def _get_files_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(File).filter_by(kb_id=kb_id_param).all()
finally:
session.close()
async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None:
for kb in self.knowledge_bases:
if kb.knowledge_base_entity.uuid == kb_uuid:
return kb
return None
return await asyncio.to_thread(_get_files_sync, kb_id)
except Exception as e:
self.ap.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True)
return []
async def get_all_files(self):
"""
Retrieves all files stored in the database, regardless of their association
with any specific knowledge base.
"""
try:
def _get_all_files_sync():
session = SessionLocal()
try:
return session.query(File).all()
finally:
session.close()
return await asyncio.to_thread(_get_all_files_sync)
except Exception as e:
self.ap.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True)
return []
async def remove_knowledge_base(self, kb_uuid: str):
for kb in self.knowledge_bases:
if kb.knowledge_base_entity.uuid == kb_uuid:
await kb.dispose()
self.knowledge_bases.remove(kb)
return
async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None):
"""
@@ -220,7 +260,8 @@ class RAGManager:
await self.ap.storage_mgr.storage_provider.delete(file_id)
except Exception as e:
self.ap.logger.error(
f'Error deleting file from storage for file_id {file_id}: {str(e)}', exc_info=True
f'Error deleting file from storage for file_id {file_id}: {str(e)}',
exc_info=True,
)
self.ap.logger.info(f'Deleted file record for file_id: {file_id}')
else:
@@ -273,7 +314,10 @@ class RAGManager:
)
except Exception as kb_del_e:
session.rollback()
self.ap.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True)
self.ap.logger.error(
f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}',
exc_info=True,
)
raise
finally:
session.close()
@@ -283,7 +327,8 @@ class RAGManager:
if session.is_active:
session.rollback()
self.ap.logger.error(
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}',
exc_info=True,
)
raise
finally:
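Pulling the pieces above together, a minimal sketch of the intended ingestion flow through the runtime objects, assuming ap is an initialized Application and the knowledge base has already been loaded by RAGManager.
async def ingest(ap, kb_uuid: str, file_id: str) -> int:
    runtime_kb = await ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
    if runtime_kb is None:
        raise ValueError(f'knowledge base {kb_uuid} is not loaded')
    # Schedules the parse -> chunk -> embed pipeline as a background task and returns
    # its task id; the file row moves pending -> processing -> completed/failed.
    return await runtime_kb.store_file(file_id)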

View File

@@ -1,43 +1,43 @@
import numpy as np
import logging
from chromadb import PersistentClient
import os
from pkg.core import app
logger = logging.getLogger(__name__)
class ChromaIndexManager:
def __init__(self, collection_name: str = "default_collection"):
self.logger = logging.getLogger(self.__class__.__name__)
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
os.makedirs(chroma_data_path, exist_ok=True)
def __init__(self, ap: app.Application, collection_name: str = 'default_collection'):
self.ap = ap
chroma_data_path = './data/chroma'
self.client = PersistentClient(path=chroma_data_path)
self._collection_name = collection_name
self._collection = None
self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}")
self.ap.logger.info(f'ChromaIndexManager initialized. Collection name: {self._collection_name}')
@property
def collection(self):
if self._collection is None:
self._collection = self.client.get_or_create_collection(name=self._collection_name)
self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
self.ap.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
return self._collection
def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]):
if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents):
raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.")
def add_embeddings_sync(
self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]
):
if (
embeddings.shape[0] != len(chunk_ids)
or embeddings.shape[0] != len(file_ids)
or embeddings.shape[0] != len(documents)
):
raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.')
chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)]
metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)]
chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)]
metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)]
self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
self.collection.add(
embeddings=embeddings.tolist(),
ids=chroma_ids,
metadatas=metadatas,
documents=documents
)
self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents)
self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
def search_sync(self, query_embedding: np.ndarray, k: int = 5):
@@ -54,12 +54,14 @@ class ChromaIndexManager:
query_embeddings=query_embedding.tolist(),
n_results=k,
# REMOVE 'ids' from the include list. It's returned by default.
include=["metadatas", "distances", "documents"]
include=['metadatas', 'distances', 'documents'],
)
self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.")
self.ap.logger.debug(f'Chroma search returned {len(results.get("ids", [[]])[0])} results.')
return results
def delete_by_file_id_sync(self, file_id: int):
self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.")
self.collection.delete(where={"file_id": file_id})
self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.")
self.ap.logger.info(
f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'."
)
self.collection.delete(where={'file_id': file_id})
self.ap.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.')
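A standalone sketch of the wrapper's synchronous helpers, assuming an initialized Application instance ap; in the actual flow these calls go through Embedder and Retriever rather than being made directly.
import numpy as np
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager  # path as in this diff

def demo(ap):
    mgr = ChromaIndexManager(ap=ap, collection_name='default_collection')
    embeddings = np.random.rand(2, 8).astype(np.float32)  # two toy 8-dim vectors
    mgr.add_embeddings_sync(
        file_ids=[1, 1],
        chunk_ids=[10, 11],
        embeddings=embeddings,
        documents=['first chunk', 'second chunk'],
    )
    query = np.random.rand(8).astype(np.float32)
    return mgr.search_sync(query, k=2)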

View File

@@ -1,21 +1,26 @@
# services/chunker.py
import logging
from typing import List
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.core import app
logger = logging.getLogger(__name__)
class Chunker(BaseService):
"""
A class for splitting long texts into smaller, overlapping chunks.
"""
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__() # Initialize BaseService
self.logger = logging.getLogger(self.__class__.__name__)
def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__(ap) # Initialize BaseService
self.ap = ap
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
if self.chunk_overlap >= self.chunk_size:
self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.")
self.ap.logger.warning(
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
)
def _split_text_sync(self, text: str) -> List[str]:
"""
@@ -27,18 +32,19 @@ class Chunker(BaseService):
# words = text.split()
# chunks = []
# current_chunk = []
# for word in words:
# current_chunk.append(word)
# if len(current_chunk) > self.chunk_size:
# chunks.append(" ".join(current_chunk[:self.chunk_size]))
# current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]
# if current_chunk:
# chunks.append(" ".join(current_chunk))
# A more robust chunking strategy (e.g., using recursive character text splitter)
from langchain.text_splitter import RecursiveCharacterTextSplitter
text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
@@ -51,8 +57,8 @@ class Chunker(BaseService):
"""
Asynchronously chunks a given text into smaller pieces.
"""
self.logger.info(f"Chunking text (length: {len(text)})...")
self.ap.logger.info(f'Chunking text (length: {len(text)})...')
# Run the synchronous splitting logic in a separate thread
chunks = await self._run_sync(self._split_text_sync, text)
self.logger.info(f"Text chunked into {len(chunks)} pieces.")
return chunks
self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.')
return chunks
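To make the chunking behaviour concrete, a small sketch of the LangChain splitter the Chunker delegates to, using the same defaults as above (chunk_size=500, chunk_overlap=50).
from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = splitter.split_text('some long document text ' * 200)
print(len(chunks), max(len(c) for c in chunks))  # number of chunks and longest chunk length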

View File

@@ -7,16 +7,10 @@ from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from sqlalchemy.orm import declarative_base, sessionmaker
from ....core import app
from ....entity.persistence import model as persistence_model
import sqlalchemy
from ....provider.modelmgr.requester import RuntimeEmbeddingModel
base = declarative_base()
logger = logging.getLogger(__name__)
class Embedder(BaseService):
def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None:
super().__init__()
@@ -30,61 +24,66 @@ class Embedder(BaseService):
This function assumes it's called within a context where the session
will be committed/rolled back and closed by the caller.
"""
self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).")
self.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).')
chunk_objects = []
for text in chunks_texts:
chunk = Chunk(file_id=file_id, text=text)
session.add(chunk)
chunk_objects.append(chunk)
session.flush() # This populates the .id attribute for each new chunk object
self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.")
session.flush() # This populates the .id attribute for each new chunk object
self.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.')
return chunk_objects
async def embed_and_store(self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel) -> List[Chunk]:
async def embed_and_store(
self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> List[Chunk]:
if not embedding_model:
raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.")
raise RuntimeError('Embedding model not loaded. Please check Embedder initialization.')
session = SessionLocal() # Start a session that will live for the whole operation
session = SessionLocal() # Start a session that will live for the whole operation
chunk_objects = []
try:
# 1. Save chunks to the relational database first to get their IDs
# We call _db_save_chunks_sync directly without _run_sync's session management
# because we manage the session here across multiple async calls.
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
session.commit() # Commit chunks to make their IDs permanent and accessible
session.commit() # Commit chunks to make their IDs permanent and accessible
if not chunk_objects:
self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.")
self.logger.warning(
f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.'
)
return []
# get the embeddings for the chunks
embeddings = []
i = 0
while i <len(chunks):
chunk = chunks[i]
embeddings: list[list[float]] = []
for chunk in chunks:
result = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunk,
)
embeddings.append(result)
i += 1
embeddings_np = np.array(embeddings, dtype=np.float32)
self.logger.info("Saving embeddings to Chroma...")
chunk_ids = [c.id for c in chunk_objects]
self.logger.info('Saving embeddings to Chroma...')
chunk_ids = [c.id for c in chunk_objects]
file_ids_for_chroma = [file_id] * len(chunk_ids)
await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call
await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call
self.chroma_manager.add_embeddings_sync,
file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents
file_ids_for_chroma,
chunk_ids,
embeddings_np,
chunks, # Pass original chunks texts for documents
)
self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.")
self.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to Chroma.')
return chunk_objects
except Exception as e:
session.rollback() # Rollback on any error
self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True)
raise # Re-raise the exception to propagate it
session.rollback() # Rollback on any error
self.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True)
raise # Re-raise the exception to propagate it
finally:
session.close() # Ensure the session is always closed
session.close() # Ensure the session is always closed
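A toy sketch of the per-chunk embedding loop introduced above, with stub objects standing in for the real RuntimeEmbeddingModel and its requester.
import asyncio
import numpy as np

class StubRequester:
    async def invoke_embedding(self, model, input_text: str) -> list[float]:
        return [float(len(input_text)), 0.0, 1.0]  # fake 3-dimensional embedding

class StubModel:
    requester = StubRequester()

async def embed_chunks(chunks: list[str]) -> np.ndarray:
    model = StubModel()
    embeddings: list[list[float]] = []
    for chunk in chunks:
        embeddings.append(await model.requester.invoke_embedding(model=model, input_text=chunk))
    return np.array(embeddings, dtype=np.float32)

print(asyncio.run(embed_chunks(['hello', 'world'])).shape)  # (2, 3)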

View File

@@ -1,4 +1,5 @@
import PyPDF2
import io
from docx import Document
import pandas as pd
import chardet
@@ -10,7 +11,7 @@ import ebooklib
from ebooklib import epub
import re
import asyncio # Import asyncio for async operations
import os
from pkg.core import app
# Configure logging
logger = logging.getLogger(__name__)
@@ -24,8 +25,8 @@ class FileParser:
to avoid blocking the asyncio event loop.
"""
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
def __init__(self, ap: app.Application):
self.ap = ap
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
"""
@@ -35,138 +36,160 @@ class FileParser:
try:
return await asyncio.to_thread(sync_func, *args, **kwargs)
except Exception as e:
self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
self.ap.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
raise
async def parse(self, file_path: str) -> Union[str, None]:
async def parse(self, file_name: str, extension: str) -> Union[str, None]:
"""
Parses the file based on its extension and returns the extracted text content.
This is the main asynchronous entry point for parsing.
Args:
file_path (str): The path to the file to be parsed.
file_name (str): The name of the file to be parsed, retrieved via ap.storage_mgr
Returns:
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
"""
if not file_path or not os.path.exists(file_path):
self.logger.error(f'Invalid file path provided: {file_path}')
return None
file_extension = file_path.split('.')[-1].lower()
file_extension = extension.lower()
parser_method = getattr(self, f'_parse_{file_extension}', None)
if parser_method is None:
self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}')
self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}')
return None
try:
# Pass file_path to the specific parser methods
return await parser_method(file_path)
return await parser_method(file_name)
except Exception as e:
self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}')
self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}')
return None
# --- Helper for reading files with encoding detection ---
async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]:
async def _read_file_content(self, file_name: str) -> Union[str, bytes]:
"""
Reads a file with automatic encoding detection, ensuring the synchronous
file read operation runs in a separate thread.
"""
def _read_sync():
with open(file_path, 'rb') as file:
raw_data = file.read()
detected = chardet.detect(raw_data)
encoding = detected['encoding'] or 'utf-8'
# def _read_sync():
# with open(file_path, 'rb') as file:
# raw_data = file.read()
# detected = chardet.detect(raw_data)
# encoding = detected['encoding'] or 'utf-8'
if mode == 'r':
return raw_data.decode(encoding, errors='ignore')
return raw_data # For binary mode
# if mode == 'r':
# return raw_data.decode(encoding, errors='ignore')
# return raw_data # For binary mode
return await self._run_sync(_read_sync)
# return await self._run_sync(_read_sync)
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
detected = chardet.detect(file_bytes)
encoding = detected['encoding'] or 'utf-8'
return file_bytes.decode(encoding, errors='ignore')
# --- Specific Parser Methods ---
async def _parse_txt(self, file_path: str) -> str:
async def _parse_txt(self, file_name: str) -> str:
"""Parses a TXT file and returns its content."""
self.logger.info(f'Parsing TXT file: {file_path}')
return await self._read_file_content(file_path, mode='r')
self.ap.logger.info(f'Parsing TXT file: {file_name}')
return await self._read_file_content(file_name)
async def _parse_pdf(self, file_path: str) -> str:
async def _parse_pdf(self, file_name: str) -> str:
"""Parses a PDF file and returns its text content."""
self.logger.info(f'Parsing PDF file: {file_path}')
self.ap.logger.info(f'Parsing PDF file: {file_name}')
# def _parse_pdf_sync():
# text_content = []
# with open(file_name, 'rb') as file:
# pdf_reader = PyPDF2.PdfReader(file)
# for page in pdf_reader.pages:
# text = page.extract_text()
# if text:
# text_content.append(text)
# return '\n'.join(text_content)
# return await self._run_sync(_parse_pdf_sync)
pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_pdf_sync():
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
text_content = []
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
text = page.extract_text()
if text:
text_content.append(text)
for page in pdf_reader.pages:
text = page.extract_text()
if text:
text_content.append(text)
return '\n'.join(text_content)
return await self._run_sync(_parse_pdf_sync)
async def _parse_docx(self, file_path: str) -> str:
async def _parse_docx(self, file_name: str) -> str:
"""Parses a DOCX file and returns its text content."""
self.logger.info(f'Parsing DOCX file: {file_path}')
self.ap.logger.info(f'Parsing DOCX file: {file_name}')
docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_docx_sync():
doc = Document(file_path)
doc = Document(io.BytesIO(docx_bytes))
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
return '\n'.join(text_content)
return await self._run_sync(_parse_docx_sync)
async def _parse_doc(self, file_path: str) -> str:
async def _parse_doc(self, file_name: str) -> str:
"""Handles .doc files, explicitly stating lack of direct support."""
self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.')
self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.')
raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')
async def _parse_xlsx(self, file_path: str) -> str:
async def _parse_xlsx(self, file_name: str) -> str:
"""Parses an XLSX file, returning text from all sheets."""
self.logger.info(f'Parsing XLSX file: {file_path}')
self.ap.logger.info(f'Parsing XLSX file: {file_name}')
xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_xlsx_sync():
excel_file = pd.ExcelFile(file_path)
excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes))
all_sheet_content = []
for sheet_name in excel_file.sheet_names:
df = pd.read_excel(file_path, sheet_name=sheet_name)
df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name)
sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
all_sheet_content.append(sheet_text)
return '\n'.join(all_sheet_content)
return await self._run_sync(_parse_xlsx_sync)
async def _parse_csv(self, file_path: str) -> str:
async def _parse_csv(self, file_name: str) -> str:
"""Parses a CSV file and returns its content as a string."""
self.logger.info(f'Parsing CSV file: {file_path}')
self.ap.logger.info(f'Parsing CSV file: {file_name}')
csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_csv_sync():
# pd.read_csv can often detect encoding, but explicit detection is safer
raw_data = self._read_file_content(
file_path, mode='rb'
) # Note: this will need to be await outside this sync function
_ = raw_data
# raw_data = self._read_file_content(
# file_name, mode='rb'
# ) # Note: this will need to be await outside this sync function
# _ = raw_data
# For simplicity, we'll let pandas handle encoding internally after a raw read.
# A more robust solution might pass encoding directly to pd.read_csv after detection.
detected = chardet.detect(open(file_path, 'rb').read())
detected = chardet.detect(csv_bytes)
encoding = detected['encoding'] or 'utf-8'
df = pd.read_csv(file_path, encoding=encoding)
df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding)
return df.to_string(index=False)
return await self._run_sync(_parse_csv_sync)
async def _parse_markdown(self, file_path: str) -> str:
async def _parse_markdown(self, file_name: str) -> str:
"""Parses a Markdown file, converting it to structured plain text."""
self.logger.info(f'Parsing Markdown file: {file_path}')
self.ap.logger.info(f'Parsing Markdown file: {file_name}')
md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_markdown_sync():
md_content = self._read_file_content(
file_path, mode='r'
) # This is a synchronous call within a sync function
md_content = io.BytesIO(md_bytes).read().decode('utf-8', errors='ignore')
html_content = markdown.markdown(
md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
)
@@ -200,12 +223,14 @@ class FileParser:
return await self._run_sync(_parse_markdown_sync)
async def _parse_html(self, file_path: str) -> str:
async def _parse_html(self, file_name: str) -> str:
"""Parses an HTML file, extracting structured plain text."""
self.logger.info(f'Parsing HTML file: {file_path}')
self.ap.logger.info(f'Parsing HTML file: {file_name}')
html_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_html_sync():
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
html_content = io.BytesIO(html_bytes).read().decode('utf-8', errors='ignore')
soup = BeautifulSoup(html_content, 'html.parser')
for script_or_style in soup(['script', 'style']):
script_or_style.decompose()
@@ -236,12 +261,14 @@ class FileParser:
return await self._run_sync(_parse_html_sync)
async def _parse_epub(self, file_path: str) -> str:
async def _parse_epub(self, file_name: str) -> str:
"""Parses an EPUB file, extracting metadata and content."""
self.logger.info(f'Parsing EPUB file: {file_path}')
self.ap.logger.info(f'Parsing EPUB file: {file_name}')
epub_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)
def _parse_epub_sync():
book = epub.read_epub(file_path)
book = epub.read_epub(io.BytesIO(epub_bytes))
text_content = []
title_meta = book.get_metadata('DC', 'title')
if title_meta:
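The dispatch pattern parse() relies on, boiled down to a runnable toy; the real class looks handlers up as _parse_<extension> methods on itself and loads bytes through ap.storage_mgr.
import asyncio

class MiniParser:
    async def parse(self, file_name: str, extension: str):
        parser_method = getattr(self, f'_parse_{extension.lower()}', None)
        if parser_method is None:
            return None  # unsupported format
        return await parser_method(file_name)

    async def _parse_txt(self, file_name: str) -> str:
        return f'contents of {file_name}'

print(asyncio.run(MiniParser().parse('notes.txt', 'txt')))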

pkg/vector/__init__.py (new file, +0 lines)
View File

pkg/vector/mgr.py (new file, +13 lines)
View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from ..core import app
class VectorDBManager:
ap: app.Application
def __init__(self, ap: app.Application):
self.ap = ap
async def initialize(self):
pass

pkg/vector/vdb.py (new file, +7 lines)
View File

@@ -0,0 +1,7 @@
from __future__ import annotations
import abc
class VectorDatabase(abc.ABC):
pass
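The new VectorDatabase base class is still an empty stub; below is a hypothetical sketch of the shape it might take, with method names that are assumptions rather than part of this commit.
from __future__ import annotations
import abc

class VectorDatabase(abc.ABC):
    @abc.abstractmethod
    async def add(self, ids: list[str], embeddings: list[list[float]]) -> None: ...

    @abc.abstractmethod
    async def search(self, embedding: list[float], k: int) -> list[str]: ...

class InMemoryVectorDatabase(VectorDatabase):
    def __init__(self) -> None:
        self._store: dict[str, list[float]] = {}

    async def add(self, ids: list[str], embeddings: list[list[float]]) -> None:
        self._store.update(zip(ids, embeddings))

    async def search(self, embedding: list[float], k: int) -> list[str]:
        def sq_dist(v: list[float]) -> float:
            return sum((a - b) ** 2 for a, b in zip(v, embedding))
        return sorted(self._store, key=lambda key: sq_dist(self._store[key]))[:k]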