mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 03:15:06 +08:00
chore: stash
@@ -14,7 +14,7 @@ preregistered_groups: list[type[RouterGroup]] = []
"""RouterGroup 的预注册列表"""

def group_class(name: str, path: str) -> None:
def group_class(name: str, path: str) -> typing.Callable[[typing.Type[RouterGroup]], typing.Type[RouterGroup]]:
"""注册一个 RouterGroup"""

def decorator(cls: typing.Type[RouterGroup]) -> typing.Type[RouterGroup]:
@@ -120,6 +120,6 @@ class RouterGroup(abc.ABC):
}
)

def http_status(self, status: int, code: int, msg: str) -> quart.Response:
def http_status(self, status: int, code: int, msg: str) -> typing.Tuple[quart.Response, int]:
"""返回一个指定状态码的响应"""
return self.fail(code, msg), status
return (self.fail(code, msg), status)
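Note: the tightened annotation matches how Quart (like Flask) treats a (response, status) tuple returned from a view. A minimal standalone sketch, assuming a bare Quart app rather than LangBot's RouterGroup wiring:

import typing

import quart

demo_app = quart.Quart(__name__)

@demo_app.route('/missing')
async def missing() -> typing.Tuple[quart.Response, int]:
    # Same shape as RouterGroup.http_status(): a JSON body plus an explicit status code.
    body = quart.Response('{"code": -1, "msg": "knowledge base not found"}', mimetype='application/json')
    return body, 404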
@@ -5,77 +5,49 @@ from ... import group
@group.group_class('knowledge_base', '/api/v1/knowledge/bases')
class KnowledgeBaseRouterGroup(group.RouterGroup):
async def initialize(self) -> None:
@self.route('', methods=['POST', 'GET'], endpoint='handle_knowledge_bases')
async def handle_knowledge_bases() -> str:
@self.route('', methods=['POST', 'GET'])
async def handle_knowledge_bases() -> quart.Response:
if quart.request.method == 'GET':
knowledge_bases = await self.ap.knowledge_base_service.get_all_knowledge_bases()
bases_list = [
{
'uuid': kb.id,
'name': kb.name,
'description': kb.description,
'embedding_model_uuid': kb.embedding_model_uuid,
'top_k': kb.top_k,
}
for kb in knowledge_bases
]
return self.success(data={'bases': bases_list})
knowledge_bases = await self.ap.knowledge_service.get_knowledge_bases()
return self.success(data={'bases': knowledge_bases})

elif quart.request.method == 'POST':
json_data = await quart.request.json
knowledge_base_uuid = await self.ap.knowledge_base_service.create_knowledge_base(
json_data.get('name'),
json_data.get('description'),
json_data.get('embedding_model_uuid'),
json_data.get('top_k',5),
)
knowledge_base_uuid = await self.ap.knowledge_service.create_knowledge_base(json_data)
return self.success(data={'uuid': knowledge_base_uuid})

return self.http_status(405, -1, 'Method not allowed')

@self.route(
'/<knowledge_base_uuid>',
methods=['GET', 'DELETE'],
endpoint='handle_specific_knowledge_base',
)
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> str:
async def handle_specific_knowledge_base(knowledge_base_uuid: str) -> quart.Response:
if quart.request.method == 'GET':
knowledge_base = await self.ap.knowledge_base_service.get_knowledge_base_by_id(knowledge_base_uuid)
knowledge_base = await self.ap.knowledge_service.get_knowledge_base(knowledge_base_uuid)

if knowledge_base is None:
return self.http_status(404, -1, 'knowledge base not found')

return self.success(
data={
'base': {
'name': knowledge_base.name,
'description': knowledge_base.description,
'uuid': knowledge_base.id,
'embedding_model_uuid': knowledge_base.embedding_model_uuid,
'top_k': knowledge_base.top_k,
},
'base': knowledge_base,
}
)
elif quart.request.method == 'DELETE':
await self.ap.knowledge_base_service.delete_kb_by_id(knowledge_base_uuid)
await self.ap.knowledge_service.delete_knowledge_base(knowledge_base_uuid)
return self.success({})

@self.route(
'/<knowledge_base_uuid>/files',
methods=['GET', 'POST'],
endpoint='get_knowledge_base_files',
)
async def get_knowledge_base_files(knowledge_base_uuid: str) -> str:
if quart.request.method == 'GET':
files = await self.ap.knowledge_base_service.get_files_by_knowledge_base(knowledge_base_uuid)
files = await self.ap.knowledge_service.get_files_by_knowledge_base(knowledge_base_uuid)
return self.success(
data={
'files': [
{
'id': file.id,
'file_name': file.file_name,
'status': file.status,
}
for file in files
],
'files': files,
}
)

@@ -86,14 +58,17 @@ class KnowledgeBaseRouterGroup(group.RouterGroup):
return self.http_status(400, -1, 'File ID is required')

# 调用服务层方法将文件与知识库关联
await self.ap.knowledge_base_service.relate_file_id_with_kb(knowledge_base_uuid, file_id)
return self.success({})
task_id = await self.ap.knowledge_service.store_file(knowledge_base_uuid, file_id)
return self.success(
{
'task_id': task_id,
}
)

@self.route(
'/<knowledge_base_uuid>/files/<file_id>',
methods=['DELETE'],
endpoint='delete_specific_file_in_kb',
)
async def delete_specific_file_in_kb(file_id: str,knowledge_base_uuid: str) -> str:
await self.ap.knowledge_base_service.delete_data_by_file_id(file_id)
async def delete_specific_file_in_kb(file_id: str, knowledge_base_uuid: str) -> str:
await self.ap.knowledge_service.delete_file(knowledge_base_uuid, file_id)
return self.success({})
90 pkg/api/http/service/knowledge.py Normal file
@@ -0,0 +1,90 @@
from __future__ import annotations

import uuid
import sqlalchemy

from ....core import app
from ....entity.persistence import rag as persistence_rag


class KnowledgeService:
    """知识库服务"""

    ap: app.Application

    def __init__(self, ap: app.Application) -> None:
        self.ap = ap

    async def get_knowledge_bases(self) -> list[dict]:
        """获取所有知识库"""
        result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))
        knowledge_bases = result.all()
        return [
            self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)
            for knowledge_base in knowledge_bases
        ]

    async def get_knowledge_base(self, kb_uuid: str) -> dict | None:
        """获取知识库"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
        )
        knowledge_base = result.first()
        if knowledge_base is None:
            return None
        return self.ap.persistence_mgr.serialize_model(persistence_rag.KnowledgeBase, knowledge_base)

    async def create_knowledge_base(self, kb_data: dict) -> str:
        """创建知识库"""
        kb_data['uuid'] = str(uuid.uuid4())
        await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.KnowledgeBase).values(kb_data))

        kb = await self.get_knowledge_base(kb_data['uuid'])

        await self.ap.rag_mgr.load_knowledge_base(kb)

        return kb_data['uuid']

    async def update_knowledge_base(self, kb_uuid: str, kb_data: dict) -> None:
        """更新知识库"""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.update(persistence_rag.KnowledgeBase)
            .values(kb_data)
            .where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
        )
        await self.ap.rag_mgr.remove_knowledge_base(kb_uuid)

        kb = await self.get_knowledge_base(kb_uuid)

        await self.ap.rag_mgr.load_knowledge_base(kb)

    async def store_file(self, kb_uuid: str, file_id: str) -> int:
        """存储文件"""
        # await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(kb_id=kb_uuid, file_id=file_id))
        # await self.ap.rag_mgr.store_file(file_id)
        runtime_kb = await self.ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
        if runtime_kb is None:
            raise Exception('Knowledge base not found')
        return await runtime_kb.store_file(file_id)

    async def get_files_by_knowledge_base(self, kb_uuid: str) -> list[dict]:
        """获取知识库文件"""
        result = await self.ap.persistence_mgr.execute_async(
            sqlalchemy.select(persistence_rag.File).where(persistence_rag.File.kb_id == kb_uuid)
        )
        files = result.all()
        return [self.ap.persistence_mgr.serialize_model(persistence_rag.File, file) for file in files]

    async def delete_file(self, kb_uuid: str, file_id: str) -> None:
        """删除文件"""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_rag.File).where(persistence_rag.File.uuid == file_id)
        )
        # TODO: remove from memory

    async def delete_knowledge_base(self, kb_uuid: str) -> None:
        """删除知识库"""
        await self.ap.persistence_mgr.execute_async(
            sqlalchemy.delete(persistence_rag.KnowledgeBase).where(persistence_rag.KnowledgeBase.uuid == kb_uuid)
        )
        # TODO: remove from memory
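A hedged usage sketch of the new service from the router's perspective; the field names come from the handler above, while the concrete values are made up for illustration:

async def create_kb_example(ap) -> str:
    # The router passes the parsed JSON body through unchanged; the service
    # generates the UUID, persists the row, and loads the runtime knowledge base.
    kb_data = {
        'name': 'docs',
        'description': 'product manuals',
        'embedding_model_uuid': 'embedding-model-uuid-here',  # hypothetical value
        'top_k': 5,
    }
    return await ap.knowledge_service.create_knowledge_base(kb_data)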
@@ -22,6 +22,7 @@ from ..api.http.service import user as user_service
from ..api.http.service import model as model_service
from ..api.http.service import pipeline as pipeline_service
from ..api.http.service import bot as bot_service
from ..api.http.service import knowledge as knowledge_service
from ..discover import engine as discover_engine
from ..storage import mgr as storagemgr
from ..utils import logcache
@@ -48,6 +49,8 @@ class Application:
model_mgr: llm_model_mgr.ModelManager = None

rag_mgr: rag_mgr.RAGManager = None

# TODO 移动到 pipeline 里
tool_mgr: llm_tool_mgr.ToolManager = None
@@ -112,7 +115,7 @@ class Application:
bot_service: bot_service.BotService = None

knowledge_base_service: rag_mgr.RAGManager = None
knowledge_service: knowledge_service.KnowledgeService = None

def __init__(self):
pass
@@ -17,6 +17,7 @@ from ...api.http.service import user as user_service
from ...api.http.service import model as model_service
from ...api.http.service import pipeline as pipeline_service
from ...api.http.service import bot as bot_service
from ...api.http.service import knowledge as knowledge_service
from ...discover import engine as discover_engine
from ...storage import mgr as storagemgr
from ...utils import logcache
@@ -89,6 +90,10 @@ class BuildAppStage(stage.BootingStage):
await pipeline_mgr.initialize()
ap.pipeline_mgr = pipeline_mgr

rag_mgr_inst = rag_mgr.RAGManager(ap)
await rag_mgr_inst.initialize_rag_system()
ap.rag_mgr = rag_mgr_inst

http_ctrl = http_controller.HTTPController(ap)
await http_ctrl.initialize()
ap.http_ctrl = http_ctrl
@@ -102,15 +107,14 @@ class BuildAppStage(stage.BootingStage):
embedding_models_service_inst = model_service.EmbeddingModelsService(ap)
ap.embedding_models_service = embedding_models_service_inst

knowledge_base_service_inst = rag_mgr.RAGManager(ap)
await knowledge_base_service_inst.initialize_rag_system()
ap.knowledge_base_service = knowledge_base_service_inst

pipeline_service_inst = pipeline_service.PipelineService(ap)
ap.pipeline_service = pipeline_service_inst

bot_service_inst = bot_service.BotService(ap)
ap.bot_service = bot_service_inst

knowledge_service_inst = knowledge_service.KnowledgeService(ap)
ap.knowledge_service = knowledge_service_inst

ctrl = controller.Controller(ap)
ap.ctrl = ctrl
@@ -1,51 +1,50 @@
from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary, Integer
from sqlalchemy.orm import declarative_base, sessionmaker
from datetime import datetime
import os
import uuid
import sqlalchemy
from .base import Base

Base = declarative_base()
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
print("Using database URL:", DATABASE_URL)
# Base = declarative_base()
# DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
# print("Using database URL:", DATABASE_URL)

engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})
# engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})

SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

def create_db_and_tables():
"""Creates all database tables defined in the Base."""
Base.metadata.create_all(bind=engine)
print('Database tables created or already exist.')
# def create_db_and_tables():
# """Creates all database tables defined in the Base."""
# Base.metadata.create_all(bind=engine)
# print('Database tables created or already exist.')


class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(String, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model_uuid = Column(String, default='')
top_k = Column(Integer, default=5)
__tablename__ = 'knowledge_bases'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
name = sqlalchemy.Column(sqlalchemy.String, index=True)
description = sqlalchemy.Column(sqlalchemy.Text)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
embedding_model_uuid = sqlalchemy.Column(sqlalchemy.String, default='')
top_k = sqlalchemy.Column(sqlalchemy.Integer, default=5)


class File(Base):
__tablename__ = 'file'
id = Column(String, primary_key=True, index=True)
kb_id = Column(String, nullable=True)
file_name = Column(String)
path = Column(String)
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0)  # 0: uploaded and processing, 1: completed, 2: failed
__tablename__ = 'knowledge_base_files'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
kb_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
file_name = sqlalchemy.Column(sqlalchemy.String)
extension = sqlalchemy.Column(sqlalchemy.String)
created_at = sqlalchemy.Column(sqlalchemy.DateTime, default=sqlalchemy.func.now())
status = sqlalchemy.Column(sqlalchemy.String, default='pending')  # pending, processing, completed, failed


class Chunk(Base):
__tablename__ = 'chunks'
id = Column(String, primary_key=True, default=lambda: str(uuid.uuid4()))
file_id = Column(String, nullable=True)
text = Column(Text)
__tablename__ = 'knowledge_base_chunks'
uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
file_id = sqlalchemy.Column(sqlalchemy.String(255), nullable=True)
text = sqlalchemy.Column(sqlalchemy.Text)

class Vector(Base):
__tablename__ = 'vectors'
id = Column(String, primary_key=True, index=True)
chunk_id = Column(String, nullable=True)
embedding = Column(LargeBinary)

# class Vector(Base):
# __tablename__ = 'knowledge_base_vectors'
# uuid = sqlalchemy.Column(sqlalchemy.String(255), primary_key=True, unique=True)
# chunk_id = sqlalchemy.Column(sqlalchemy.String, nullable=True)
# embedding = sqlalchemy.Column(sqlalchemy.LargeBinary)
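A minimal sketch of exercising the reworked knowledge_bases table with plain SQLAlchemy; the in-memory SQLite engine is an assumption for illustration only, and LangBot itself goes through persistence_mgr rather than a raw engine:

import uuid

import sqlalchemy

engine = sqlalchemy.create_engine('sqlite:///:memory:')
Base.metadata.create_all(engine)  # Base as imported from .base above

with engine.begin() as conn:
    # Insert one row using the new uuid/top_k column layout, then read it back.
    conn.execute(
        sqlalchemy.insert(KnowledgeBase).values(
            uuid=str(uuid.uuid4()),
            name='docs',
            description='demo knowledge base',
            embedding_model_uuid='',
            top_k=5,
        )
    )
    rows = conn.execute(sqlalchemy.select(KnowledgeBase)).all()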
@@ -119,7 +119,7 @@ class EventLogger:
async def _truncate_logs(self):
if len(self.logs) > MAX_LOG_COUNT:
for i in range(DELETE_COUNT_PER_TIME):
for image_key in self.logs[i].images:
for image_key in self.logs[i].images:  # type: ignore
await self.ap.storage_mgr.storage_provider.delete(image_key)
self.logs = self.logs[DELETE_COUNT_PER_TIME:]
@@ -1,149 +1,189 @@
from __future__ import annotations
import os
import asyncio
import traceback
import uuid
from pkg.rag.knowledge.services.parser import FileParser
from pkg.rag.knowledge.services.chunker import Chunker
from pkg.rag.knowledge.services.database import create_db_and_tables, SessionLocal, KnowledgeBase, File, Chunk
from pkg.rag.knowledge.services.database import (
KnowledgeBase,
File,
Chunk,
)
from pkg.core import app
from pkg.rag.knowledge.services.embedder import Embedder
from pkg.rag.knowledge.services.retriever import Retriever
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from ...entity.persistence import model as persistence_model
from pkg.core import taskmgr
from ...entity.persistence import rag as persistence_rag
import sqlalchemy


class RuntimeKnowledgeBase:
ap: app.Application

knowledge_base_entity: persistence_rag.KnowledgeBase

chroma_manager: ChromaIndexManager

parser: FileParser

chunker: Chunker

embedder: Embedder

retriever: Retriever

def __init__(self, ap: app.Application, knowledge_base_entity: persistence_rag.KnowledgeBase):
self.ap = ap
self.knowledge_base_entity = knowledge_base_entity
self.chroma_manager = ChromaIndexManager(ap=self.ap)
self.parser = FileParser(ap=self.ap)
self.chunker = Chunker(ap=self.ap)
self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)

async def initialize(self):
pass

async def _store_file_task(self, file: persistence_rag.File, task_context: taskmgr.TaskContext):
try:
# set file status to processing
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
.where(persistence_rag.File.uuid == file.uuid)
.values(status='processing')
)

task_context.set_current_action('Parsing file')
# parse file
text = await self.parser.parse(file.file_name, file.extension)
if not text:
raise Exception(f'No text extracted from file {file.file_name}')

task_context.set_current_action('Chunking file')
# chunk file
chunks_texts = await self.chunker.chunk(text)
if not chunks_texts:
raise Exception(f'No chunks extracted from file {file.file_name}')

task_context.set_current_action('Embedding chunks')
# embed chunks
await self.embedder.embed_and_store(
file_id=file.uuid, chunks=chunks_texts, embedding_model=self.knowledge_base_entity.embedding_model_uuid
)

# set file status to completed
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
.where(persistence_rag.File.uuid == file.uuid)
.values(status='completed')
)

except Exception as e:
self.ap.logger.error(f'Error storing file {file.file_id}: {e}')
# set file status to failed
await self.ap.persistence_mgr.execute_async(
sqlalchemy.update(persistence_rag.File)
.where(persistence_rag.File.uuid == file.uuid)
.values(status='failed')
)

raise

async def store_file(self, file_id: str) -> int:
# pre checking
if not await self.ap.storage_mgr.storage_provider.exists(file_id):
raise Exception(f'File {file_id} not found')

file_uuid = str(uuid.uuid4())
kb_id = self.knowledge_base_entity.uuid
file_name = file_id
extension = os.path.splitext(file_id)[1].lstrip('.')

file = persistence_rag.File(
uuid=file_uuid,
kb_id=kb_id,
file_name=file_name,
extension=extension,
status='pending',
)

await self.ap.persistence_mgr.execute_async(sqlalchemy.insert(persistence_rag.File).values(**file.to_dict()))

# run background task asynchronously
ctx = taskmgr.TaskContext.new()
wrapper = self.ap.task_mgr.create_user_task(
self._store_file_task(file, task_context=ctx),
kind='knowledge-operation',
name=f'knowledge-store-file-{file_id}',
label=f'Store file {file_id}',
context=ctx,
)
return wrapper.id

async def dispose(self):
pass


class RAGManager:
ap: app.Application

knowledge_bases: list[RuntimeKnowledgeBase]

def __init__(self, ap: app.Application):
self.ap = ap
self.chroma_manager = ChromaIndexManager()
self.parser = FileParser()
self.chunker = Chunker()
self.embedder = Embedder(ap=self.ap, chroma_manager=self.chroma_manager)
self.retriever = Retriever(ap=self.ap, chroma_manager=self.chroma_manager)
self.knowledge_bases = []

async def initialize_rag_system(self):
"""Initializes the RAG system by creating database tables."""
await asyncio.to_thread(create_db_and_tables)
async def initialize(self):
pass

async def create_knowledge_base(
self, kb_name: str, kb_description: str, embedding_model_uuid: str = '', top_k: int = 5
):
"""
Creates a new knowledge base if it doesn't already exist.
"""
try:
if not kb_name:
raise ValueError('Knowledge base name must be set while creating.')
async def load_knowledge_bases_from_db(self):
self.ap.logger.info('Loading knowledge bases from db...')

def _create_kb_sync():
session = SessionLocal()
try:
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
if not kb:
id = str(uuid.uuid4())
new_kb = KnowledgeBase(
name=kb_name,
description=kb_description,
embedding_model_uuid=embedding_model_uuid,
top_k=top_k,
id=id,
)
session.add(new_kb)
session.commit()
session.refresh(new_kb)
self.ap.logger.info(f"Knowledge Base '{kb_name}' created.")
print(embedding_model_uuid)
return new_kb.id
else:
self.ap.logger.info(f"Knowledge Base '{kb_name}' already exists.")
except Exception as e:
session.rollback()
self.ap.logger.error(f"Error in _create_kb_sync for '{kb_name}': {str(e)}", exc_info=True)
raise
finally:
session.close()
self.knowledge_bases = []

return await asyncio.to_thread(_create_kb_sync)
except Exception as e:
self.ap.logger.error(f"Error creating knowledge base '{kb_name}': {str(e)}", exc_info=True)
raise
result = await self.ap.persistence_mgr.execute_async(sqlalchemy.select(persistence_rag.KnowledgeBase))

async def get_all_knowledge_bases(self):
"""
Retrieves all knowledge bases from the database.
"""
try:
knowledge_bases = result.all()

def _get_all_kbs_sync():
session = SessionLocal()
try:
return session.query(KnowledgeBase).all()
finally:
session.close()
for knowledge_base in knowledge_bases:
try:
await self.load_knowledge_base(knowledge_base)
except Exception as e:
self.ap.logger.error(
f'Error loading knowledge base {knowledge_base.uuid}: {e}\n{traceback.format_exc()}'
)

return await asyncio.to_thread(_get_all_kbs_sync)
except Exception as e:
self.ap.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True)
return []
async def load_knowledge_base(
self,
knowledge_base_entity: persistence_rag.KnowledgeBase | sqlalchemy.Row | dict,
) -> RuntimeKnowledgeBase:
if isinstance(knowledge_base_entity, sqlalchemy.Row):
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity._mapping)
elif isinstance(knowledge_base_entity, dict):
knowledge_base_entity = persistence_rag.KnowledgeBase(**knowledge_base_entity)

async def get_knowledge_base_by_id(self, kb_id: str):
"""
Retrieves a specific knowledge base by its ID.
"""
try:
runtime_knowledge_base = RuntimeKnowledgeBase(ap=self.ap, knowledge_base_entity=knowledge_base_entity)

def _get_kb_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(KnowledgeBase).filter_by(id=kb_id_param).first()
finally:
session.close()
await runtime_knowledge_base.initialize()

return await asyncio.to_thread(_get_kb_sync, kb_id)
except Exception as e:
self.ap.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True)
return None
self.knowledge_bases.append(runtime_knowledge_base)

async def get_files_by_knowledge_base(self, kb_id: str):
"""
Retrieves files associated with a specific knowledge base by querying the File table directly.
"""
try:
return runtime_knowledge_base

def _get_files_sync(kb_id_param):
session = SessionLocal()
try:
return session.query(File).filter_by(kb_id=kb_id_param).all()
finally:
session.close()
async def get_knowledge_base_by_uuid(self, kb_uuid: str) -> RuntimeKnowledgeBase | None:
for kb in self.knowledge_bases:
if kb.knowledge_base_entity.uuid == kb_uuid:
return kb
return None

return await asyncio.to_thread(_get_files_sync, kb_id)
except Exception as e:
self.ap.logger.error(f'Error retrieving files for knowledge base ID {kb_id}: {str(e)}', exc_info=True)
return []

async def get_all_files(self):
"""
Retrieves all files stored in the database, regardless of their association
with any specific knowledge base.
"""
try:

def _get_all_files_sync():
session = SessionLocal()
try:
return session.query(File).all()
finally:
session.close()

return await asyncio.to_thread(_get_all_files_sync)
except Exception as e:
self.ap.logger.error(f'Error retrieving all files: {str(e)}', exc_info=True)
return []
async def remove_knowledge_base(self, kb_uuid: str):
for kb in self.knowledge_bases:
if kb.knowledge_base_entity.uuid == kb_uuid:
await kb.dispose()
self.knowledge_bases.remove(kb)
return

async def store_data(self, file_path: str, kb_id: str, file_type: str, file_id: str = None):
"""
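A hedged usage sketch of the new ingestion flow (names taken from the code above; the Application wiring is assumed): store_file returns a task id right away, and the parse -> chunk -> embed pipeline runs in the background task created via ap.task_mgr.

async def ingest_example(ap, kb_uuid: str, file_id: str) -> int:
    # The file must already exist in ap.storage_mgr before it can be ingested.
    runtime_kb = await ap.rag_mgr.get_knowledge_base_by_uuid(kb_uuid)
    if runtime_kb is None:
        raise RuntimeError('knowledge base not loaded')
    task_id = await runtime_kb.store_file(file_id)
    return task_id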
@@ -220,7 +260,8 @@ class RAGManager:
await self.ap.storage_mgr.storage_provider.delete(file_id)
except Exception as e:
self.ap.logger.error(
f'Error deleting file from storage for file_id {file_id}: {str(e)}', exc_info=True
f'Error deleting file from storage for file_id {file_id}: {str(e)}',
exc_info=True,
)
self.ap.logger.info(f'Deleted file record for file_id: {file_id}')
else:
@@ -273,7 +314,10 @@ class RAGManager:
)
except Exception as kb_del_e:
session.rollback()
self.ap.logger.error(f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}', exc_info=True)
self.ap.logger.error(
f'Error deleting KnowledgeBase record for ID {kb_id}: {kb_del_e}',
exc_info=True,
)
raise
finally:
session.close()
@@ -283,7 +327,8 @@ class RAGManager:
if session.is_active:
session.rollback()
self.ap.logger.error(
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}', exc_info=True
f'Error during overall knowledge base deletion for ID {kb_id}: {str(e)}',
exc_info=True,
)
raise
finally:
@@ -1,43 +1,43 @@
import numpy as np
import logging
from chromadb import PersistentClient
import os
from pkg.core import app

logger = logging.getLogger(__name__)


class ChromaIndexManager:
def __init__(self, collection_name: str = "default_collection"):
self.logger = logging.getLogger(self.__class__.__name__)
chroma_data_path = os.path.abspath(os.path.join(__file__, "../../../../../../data/chroma"))
os.makedirs(chroma_data_path, exist_ok=True)
def __init__(self, ap: app.Application, collection_name: str = 'default_collection'):
self.ap = ap
chroma_data_path = './data/chroma'
self.client = PersistentClient(path=chroma_data_path)
self._collection_name = collection_name
self._collection = None

self.logger.info(f"ChromaIndexManager initialized. Collection name: {self._collection_name}")
self.ap.logger.info(f'ChromaIndexManager initialized. Collection name: {self._collection_name}')

@property
def collection(self):
if self._collection is None:
self._collection = self.client.get_or_create_collection(name=self._collection_name)
self.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
self.ap.logger.info(f"Chroma collection '{self._collection_name}' accessed/created.")
return self._collection

def add_embeddings_sync(self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]):
if embeddings.shape[0] != len(chunk_ids) or embeddings.shape[0] != len(file_ids) or embeddings.shape[0] != len(documents):
raise ValueError("Embedding, file_id, chunk_id, and document count mismatch.")
def add_embeddings_sync(
self, file_ids: list[int], chunk_ids: list[int], embeddings: np.ndarray, documents: list[str]
):
if (
embeddings.shape[0] != len(chunk_ids)
or embeddings.shape[0] != len(file_ids)
or embeddings.shape[0] != len(documents)
):
raise ValueError('Embedding, file_id, chunk_id, and document count mismatch.')

chroma_ids = [f"{file_id}_{chunk_id}" for file_id, chunk_id in zip(file_ids, chunk_ids)]
metadatas = [{"file_id": fid, "chunk_id": cid} for fid, cid in zip(file_ids, chunk_ids)]
chroma_ids = [f'{file_id}_{chunk_id}' for file_id, chunk_id in zip(file_ids, chunk_ids)]
metadatas = [{'file_id': fid, 'chunk_id': cid} for fid, cid in zip(file_ids, chunk_ids)]

self.logger.debug(f"Adding {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")
self.collection.add(
embeddings=embeddings.tolist(),
ids=chroma_ids,
metadatas=metadatas,
documents=documents
)
self.collection.add(embeddings=embeddings.tolist(), ids=chroma_ids, metadatas=metadatas, documents=documents)
self.logger.info(f"Added {len(embeddings)} embeddings to Chroma collection '{self._collection_name}'.")

def search_sync(self, query_embedding: np.ndarray, k: int = 5):
@@ -54,12 +54,14 @@ class ChromaIndexManager:
query_embeddings=query_embedding.tolist(),
n_results=k,
# REMOVE 'ids' from the include list. It's returned by default.
include=["metadatas", "distances", "documents"]
include=['metadatas', 'distances', 'documents'],
)
self.logger.debug(f"Chroma search returned {len(results.get('ids', [[]])[0])} results.")
self.logger.debug(f'Chroma search returned {len(results.get("ids", [[]])[0])} results.')
return results

def delete_by_file_id_sync(self, file_id: int):
self.logger.info(f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'.")
self.collection.delete(where={"file_id": file_id})
self.logger.info(f"Deleted embeddings for file_id: {file_id} from Chroma.")
self.logger.info(
f"Deleting embeddings for file_id: {file_id} from Chroma collection '{self._collection_name}'."
)
self.collection.delete(where={'file_id': file_id})
self.logger.info(f'Deleted embeddings for file_id: {file_id} from Chroma.')
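For reference, a standalone sketch of the underlying chromadb calls the manager wraps; the path, ids and vectors below are made up for illustration:

import numpy as np
from chromadb import PersistentClient

client = PersistentClient(path='./data/chroma')
collection = client.get_or_create_collection(name='default_collection')

# Two 4-dimensional dummy vectors, keyed by "<file_id>_<chunk_id>" as above.
embeddings = np.random.rand(2, 4).astype(np.float32)
collection.add(
    embeddings=embeddings.tolist(),
    ids=['file1_0', 'file1_1'],
    metadatas=[{'file_id': 'file1', 'chunk_id': 0}, {'file_id': 'file1', 'chunk_id': 1}],
    documents=['first chunk', 'second chunk'],
)
results = collection.query(
    query_embeddings=embeddings[:1].tolist(),
    n_results=2,
    include=['metadatas', 'distances', 'documents'],
)
collection.delete(where={'file_id': 'file1'})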
@@ -1,21 +1,26 @@
# services/chunker.py
import logging
from typing import List
from pkg.rag.knowledge.services.base_service import BaseService # Assuming BaseService provides _run_sync
from pkg.rag.knowledge.services.base_service import BaseService  # Assuming BaseService provides _run_sync
from pkg.core import app

logger = logging.getLogger(__name__)


class Chunker(BaseService):
"""
A class for splitting long texts into smaller, overlapping chunks.
"""
def __init__(self, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__() # Initialize BaseService
self.logger = logging.getLogger(self.__class__.__name__)

def __init__(self, ap: app.Application, chunk_size: int = 500, chunk_overlap: int = 50):
super().__init__(ap) # Initialize BaseService
self.ap = ap
self.chunk_size = chunk_size
self.chunk_overlap = chunk_overlap
if self.chunk_overlap >= self.chunk_size:
self.logger.warning("Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.")
self.logger.warning(
'Chunk overlap is greater than or equal to chunk size. This may lead to empty or malformed chunks.'
)

def _split_text_sync(self, text: str) -> List[str]:
"""
@@ -27,18 +32,19 @@ class Chunker(BaseService):
# words = text.split()
# chunks = []
# current_chunk = []

# for word in words:
# current_chunk.append(word)
# if len(current_chunk) > self.chunk_size:
# chunks.append(" ".join(current_chunk[:self.chunk_size]))
# current_chunk = current_chunk[self.chunk_size - self.chunk_overlap:]

# if current_chunk:
# chunks.append(" ".join(current_chunk))

# A more robust chunking strategy (e.g., using recursive character text splitter)
from langchain.text_splitter import RecursiveCharacterTextSplitter

text_splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size,
chunk_overlap=self.chunk_overlap,
@@ -51,8 +57,8 @@ class Chunker(BaseService):
"""
Asynchronously chunks a given text into smaller pieces.
"""
self.logger.info(f"Chunking text (length: {len(text)})...")
self.ap.logger.info(f'Chunking text (length: {len(text)})...')
# Run the synchronous splitting logic in a separate thread
chunks = await self._run_sync(self._split_text_sync, text)
self.logger.info(f"Text chunked into {len(chunks)} pieces.")
return chunks
self.ap.logger.info(f'Text chunked into {len(chunks)} pieces.')
return chunks
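A minimal sketch of the splitter call that _split_text_sync delegates to, with separators left at the langchain defaults:

from langchain.text_splitter import RecursiveCharacterTextSplitter

splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=50)
chunks = splitter.split_text('some long document text ... ' * 200)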
@@ -7,16 +7,10 @@ from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
from pkg.rag.knowledge.services.database import Chunk, SessionLocal
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from sqlalchemy.orm import declarative_base, sessionmaker
from ....core import app
from ....entity.persistence import model as persistence_model
import sqlalchemy
from ....provider.modelmgr.requester import RuntimeEmbeddingModel


base = declarative_base()
logger = logging.getLogger(__name__)

class Embedder(BaseService):
def __init__(self, ap: app.Application, chroma_manager: ChromaIndexManager = None) -> None:
super().__init__()
@@ -30,61 +24,66 @@ class Embedder(BaseService):
This function assumes it's called within a context where the session
will be committed/rolled back and closed by the caller.
"""
self.logger.debug(f"Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).")
self.logger.debug(f'Saving {len(chunks_texts)} chunks for file_id {file_id} to DB (sync).')
chunk_objects = []
for text in chunks_texts:
chunk = Chunk(file_id=file_id, text=text)
session.add(chunk)
chunk_objects.append(chunk)
session.flush() # This populates the .id attribute for each new chunk object
self.logger.debug(f"Successfully added {len(chunk_objects)} chunk entries to DB.")
session.flush()  # This populates the .id attribute for each new chunk object
self.logger.debug(f'Successfully added {len(chunk_objects)} chunk entries to DB.')
return chunk_objects

async def embed_and_store(self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel) -> List[Chunk]:
async def embed_and_store(
self, file_id: int, chunks: List[str], embedding_model: RuntimeEmbeddingModel
) -> List[Chunk]:
if not embedding_model:
raise RuntimeError("Embedding model not loaded. Please check Embedder initialization.")
raise RuntimeError('Embedding model not loaded. Please check Embedder initialization.')

session = SessionLocal() # Start a session that will live for the whole operation
session = SessionLocal()  # Start a session that will live for the whole operation
chunk_objects = []
try:
# 1. Save chunks to the relational database first to get their IDs
# We call _db_save_chunks_sync directly without _run_sync's session management
# because we manage the session here across multiple async calls.
chunk_objects = await asyncio.to_thread(self._db_save_chunks_sync, session, file_id, chunks)
session.commit() # Commit chunks to make their IDs permanent and accessible
session.commit()  # Commit chunks to make their IDs permanent and accessible

if not chunk_objects:
self.logger.warning(f"No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.")
self.logger.warning(
f'No chunk objects created for file_id {file_id}. Skipping embedding and Chroma storage.'
)
return []

# get the embeddings for the chunks
embeddings = []
i = 0
while i <len(chunks):
chunk = chunks[i]
embeddings: list[list[float]] = []

for chunk in chunks:
result = await embedding_model.requester.invoke_embedding(
model=embedding_model,
input_text=chunk,
)
embeddings.append(result)
i += 1

embeddings_np = np.array(embeddings, dtype=np.float32)

self.logger.info("Saving embeddings to Chroma...")
chunk_ids = [c.id for c in chunk_objects]
self.logger.info('Saving embeddings to Chroma...')
chunk_ids = [c.id for c in chunk_objects]
file_ids_for_chroma = [file_id] * len(chunk_ids)

await self._run_sync( # Use _run_sync for the Chroma operation, as it's a sync call
await self._run_sync(  # Use _run_sync for the Chroma operation, as it's a sync call
self.chroma_manager.add_embeddings_sync,
file_ids_for_chroma, chunk_ids, embeddings_np, chunks # Pass original chunks texts for documents
file_ids_for_chroma,
chunk_ids,
embeddings_np,
chunks,  # Pass original chunks texts for documents
)
self.logger.info(f"Successfully saved {len(chunk_objects)} embeddings to Chroma.")
self.logger.info(f'Successfully saved {len(chunk_objects)} embeddings to Chroma.')
return chunk_objects

except Exception as e:
session.rollback() # Rollback on any error
self.logger.error(f"Failed to process and store data for file_id {file_id}: {e}", exc_info=True)
raise # Re-raise the exception to propagate it
session.rollback()  # Rollback on any error
self.logger.error(f'Failed to process and store data for file_id {file_id}: {e}', exc_info=True)
raise  # Re-raise the exception to propagate it
finally:
session.close() # Ensure the session is always closed
session.close()  # Ensure the session is always closed
@@ -1,4 +1,5 @@
import PyPDF2
import io
from docx import Document
import pandas as pd
import chardet
@@ -10,7 +11,7 @@ import ebooklib
from ebooklib import epub
import re
import asyncio # Import asyncio for async operations
import os
from pkg.core import app

# Configure logging
logger = logging.getLogger(__name__)
@@ -24,8 +25,8 @@ class FileParser:
to avoid blocking the asyncio event loop.
"""

def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
def __init__(self, ap: app.Application):
self.ap = ap

async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
"""
@@ -35,138 +36,160 @@ class FileParser:
try:
return await asyncio.to_thread(sync_func, *args, **kwargs)
except Exception as e:
self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
self.ap.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
raise

async def parse(self, file_path: str) -> Union[str, None]:
async def parse(self, file_name: str, extension: str) -> Union[str, None]:
"""
Parses the file based on its extension and returns the extracted text content.
This is the main asynchronous entry point for parsing.

Args:
file_path (str): The path to the file to be parsed.
file_name (str): The name of the file to be parsed, get from ap.storage_mgr

Returns:
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
"""
if not file_path or not os.path.exists(file_path):
self.logger.error(f'Invalid file path provided: {file_path}')
return None

file_extension = file_path.split('.')[-1].lower()
file_extension = extension.lower()
parser_method = getattr(self, f'_parse_{file_extension}', None)

if parser_method is None:
self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}')
self.ap.logger.error(f'Unsupported file format: {file_extension} for file {file_name}')
return None

try:
# Pass file_path to the specific parser methods
return await parser_method(file_path)
return await parser_method(file_name)
except Exception as e:
self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}')
self.ap.logger.error(f'Failed to parse {file_extension} file {file_name}: {e}')
return None

# --- Helper for reading files with encoding detection ---
async def _read_file_content(self, file_path: str, mode: str = 'r') -> Union[str, bytes]:
async def _read_file_content(self, file_name: str) -> Union[str, bytes]:
"""
Reads a file with automatic encoding detection, ensuring the synchronous
file read operation runs in a separate thread.
"""

def _read_sync():
with open(file_path, 'rb') as file:
raw_data = file.read()
detected = chardet.detect(raw_data)
encoding = detected['encoding'] or 'utf-8'
# def _read_sync():
# with open(file_path, 'rb') as file:
# raw_data = file.read()
# detected = chardet.detect(raw_data)
# encoding = detected['encoding'] or 'utf-8'

if mode == 'r':
return raw_data.decode(encoding, errors='ignore')
return raw_data # For binary mode
# if mode == 'r':
# return raw_data.decode(encoding, errors='ignore')
# return raw_data # For binary mode

return await self._run_sync(_read_sync)
# return await self._run_sync(_read_sync)
file_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

detected = chardet.detect(file_bytes)
encoding = detected['encoding'] or 'utf-8'

return file_bytes.decode(encoding, errors='ignore')

# --- Specific Parser Methods ---

async def _parse_txt(self, file_path: str) -> str:
async def _parse_txt(self, file_name: str) -> str:
"""Parses a TXT file and returns its content."""
self.logger.info(f'Parsing TXT file: {file_path}')
return await self._read_file_content(file_path, mode='r')
self.ap.logger.info(f'Parsing TXT file: {file_name}')
return await self._read_file_content(file_name)

async def _parse_pdf(self, file_path: str) -> str:
async def _parse_pdf(self, file_name: str) -> str:
"""Parses a PDF file and returns its text content."""
self.logger.info(f'Parsing PDF file: {file_path}')
self.ap.logger.info(f'Parsing PDF file: {file_name}')

# def _parse_pdf_sync():
# text_content = []
# with open(file_name, 'rb') as file:
# pdf_reader = PyPDF2.PdfReader(file)
# for page in pdf_reader.pages:
# text = page.extract_text()
# if text:
# text_content.append(text)
# return '\n'.join(text_content)

# return await self._run_sync(_parse_pdf_sync)

pdf_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_pdf_sync():
pdf_reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
text_content = []
with open(file_path, 'rb') as file:
pdf_reader = PyPDF2.PdfReader(file)
for page in pdf_reader.pages:
text = page.extract_text()
if text:
text_content.append(text)
for page in pdf_reader.pages:
text = page.extract_text()
if text:
text_content.append(text)
return '\n'.join(text_content)

return await self._run_sync(_parse_pdf_sync)

async def _parse_docx(self, file_path: str) -> str:
async def _parse_docx(self, file_name: str) -> str:
"""Parses a DOCX file and returns its text content."""
self.logger.info(f'Parsing DOCX file: {file_path}')
self.ap.logger.info(f'Parsing DOCX file: {file_name}')

docx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_docx_sync():
doc = Document(file_path)
doc = Document(io.BytesIO(docx_bytes))
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
return '\n'.join(text_content)

return await self._run_sync(_parse_docx_sync)

async def _parse_doc(self, file_path: str) -> str:
async def _parse_doc(self, file_name: str) -> str:
"""Handles .doc files, explicitly stating lack of direct support."""
self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.')
self.ap.logger.warning(f'Direct .doc parsing is not supported for {file_name}. Please convert to .docx first.')
raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')

async def _parse_xlsx(self, file_path: str) -> str:
async def _parse_xlsx(self, file_name: str) -> str:
"""Parses an XLSX file, returning text from all sheets."""
self.logger.info(f'Parsing XLSX file: {file_path}')
self.ap.logger.info(f'Parsing XLSX file: {file_name}')

xlsx_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_xlsx_sync():
excel_file = pd.ExcelFile(file_path)
excel_file = pd.ExcelFile(io.BytesIO(xlsx_bytes))
all_sheet_content = []
for sheet_name in excel_file.sheet_names:
df = pd.read_excel(file_path, sheet_name=sheet_name)
df = pd.read_excel(io.BytesIO(xlsx_bytes), sheet_name=sheet_name)
sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
all_sheet_content.append(sheet_text)
return '\n'.join(all_sheet_content)

return await self._run_sync(_parse_xlsx_sync)

async def _parse_csv(self, file_path: str) -> str:
async def _parse_csv(self, file_name: str) -> str:
"""Parses a CSV file and returns its content as a string."""
self.logger.info(f'Parsing CSV file: {file_path}')
self.ap.logger.info(f'Parsing CSV file: {file_name}')

csv_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_csv_sync():
# pd.read_csv can often detect encoding, but explicit detection is safer
raw_data = self._read_file_content(
file_path, mode='rb'
) # Note: this will need to be await outside this sync function
_ = raw_data
# raw_data = self._read_file_content(
# file_name, mode='rb'
# ) # Note: this will need to be await outside this sync function
# _ = raw_data
# For simplicity, we'll let pandas handle encoding internally after a raw read.
# A more robust solution might pass encoding directly to pd.read_csv after detection.
detected = chardet.detect(open(file_path, 'rb').read())
detected = chardet.detect(io.BytesIO(csv_bytes))
encoding = detected['encoding'] or 'utf-8'
df = pd.read_csv(file_path, encoding=encoding)
df = pd.read_csv(io.BytesIO(csv_bytes), encoding=encoding)
return df.to_string(index=False)

return await self._run_sync(_parse_csv_sync)

async def _parse_markdown(self, file_path: str) -> str:
async def _parse_markdown(self, file_name: str) -> str:
"""Parses a Markdown file, converting it to structured plain text."""
self.logger.info(f'Parsing Markdown file: {file_path}')
self.ap.logger.info(f'Parsing Markdown file: {file_name}')

md_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_markdown_sync():
md_content = self._read_file_content(
file_path, mode='r'
) # This is a synchronous call within a sync function
md_content = io.BytesIO(md_bytes).read().decode('utf-8', errors='ignore')
html_content = markdown.markdown(
md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
)
@@ -200,12 +223,14 @@ class FileParser:

return await self._run_sync(_parse_markdown_sync)

async def _parse_html(self, file_path: str) -> str:
async def _parse_html(self, file_name: str) -> str:
"""Parses an HTML file, extracting structured plain text."""
self.logger.info(f'Parsing HTML file: {file_path}')
self.ap.logger.info(f'Parsing HTML file: {file_name}')

html_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_html_sync():
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
html_content = io.BytesIO(html_bytes).read().decode('utf-8', errors='ignore')
soup = BeautifulSoup(html_content, 'html.parser')
for script_or_style in soup(['script', 'style']):
script_or_style.decompose()
@@ -236,12 +261,14 @@ class FileParser:

return await self._run_sync(_parse_html_sync)

async def _parse_epub(self, file_path: str) -> str:
async def _parse_epub(self, file_name: str) -> str:
"""Parses an EPUB file, extracting metadata and content."""
self.logger.info(f'Parsing EPUB file: {file_path}')
self.ap.logger.info(f'Parsing EPUB file: {file_name}')

epub_bytes = await self.ap.storage_mgr.storage_provider.load(file_name)

def _parse_epub_sync():
book = epub.read_epub(file_path)
book = epub.read_epub(io.BytesIO(epub_bytes))
text_content = []
title_meta = book.get_metadata('DC', 'title')
if title_meta:
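The same load-bytes-then-parse-in-memory pattern applies to every format; a hedged sketch for PDF, where the storage provider call is taken from the diff and the helper name is hypothetical:

import io

import PyPDF2

async def parse_pdf_example(ap, file_name: str) -> str:
    # Load the raw bytes from LangBot's storage provider, then parse entirely in memory.
    pdf_bytes = await ap.storage_mgr.storage_provider.load(file_name)
    reader = PyPDF2.PdfReader(io.BytesIO(pdf_bytes))
    return '\n'.join(page.extract_text() or '' for page in reader.pages)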
0 pkg/vector/__init__.py Normal file
13 pkg/vector/mgr.py Normal file
@@ -0,0 +1,13 @@
from __future__ import annotations

from ..core import app


class VectorDBManager:
    ap: app.Application

    def __init__(self, ap: app.Application):
        self.ap = ap

    async def initialize(self):
        pass
7 pkg/vector/vdb.py Normal file
@@ -0,0 +1,7 @@
from __future__ import annotations

import abc


class VectorDatabase(abc.ABC):
    pass