mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
fix: create knwoledge base issue
This commit is contained in:
@@ -1,19 +1,17 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, LargeBinary
|
||||
from sqlalchemy import create_engine, Column, String, Text, DateTime, LargeBinary, Integer
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker
|
||||
from datetime import datetime
|
||||
import os
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
DATABASE_URL = os.getenv('DATABASE_URL', 'sqlite:///./rag_knowledge.db')
|
||||
print("Using database URL:", DATABASE_URL)
|
||||
|
||||
|
||||
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False})
|
||||
|
||||
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def create_db_and_tables():
|
||||
"""Creates all database tables defined in the Base."""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
@@ -22,35 +20,31 @@ def create_db_and_tables():
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model_uuid = Column(String, default='')
|
||||
top_k = Column(Integer, default=5)
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
kb_id = Column(Integer, nullable=True)
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
kb_id = Column(String, nullable=True)
|
||||
file_name = Column(String)
|
||||
path = Column(String)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
file_type = Column(String)
|
||||
status = Column(Integer, default=0)
|
||||
|
||||
status = Column(String, default='0')
|
||||
|
||||
class Chunk(Base):
|
||||
__tablename__ = 'chunks'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
file_id = Column(Integer, nullable=True)
|
||||
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
file_id = Column(String, nullable=True)
|
||||
text = Column(Text)
|
||||
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, nullable=True)
|
||||
id = Column(String, primary_key=True, index=True)
|
||||
chunk_id = Column(String, nullable=True)
|
||||
embedding = Column(LargeBinary)
|
||||
|
||||
@@ -41,7 +41,7 @@ class RAGManager:
|
||||
try:
|
||||
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
|
||||
if not kb:
|
||||
id = uuid.uuid4().int
|
||||
id = str(uuid.uuid4())
|
||||
new_kb = KnowledgeBase(
|
||||
name=kb_name,
|
||||
description=kb_description,
|
||||
@@ -86,7 +86,7 @@ class RAGManager:
|
||||
self.logger.error(f'Error retrieving knowledge bases: {str(e)}', exc_info=True)
|
||||
return []
|
||||
|
||||
async def get_knowledge_base_by_id(self, kb_id: int):
|
||||
async def get_knowledge_base_by_id(self, kb_id: str):
|
||||
"""
|
||||
Retrieves a specific knowledge base by its ID.
|
||||
"""
|
||||
@@ -104,7 +104,7 @@ class RAGManager:
|
||||
self.logger.error(f'Error retrieving knowledge base with ID {kb_id}: {str(e)}', exc_info=True)
|
||||
return None
|
||||
|
||||
async def get_files_by_knowledge_base(self, kb_id: int):
|
||||
async def get_files_by_knowledge_base(self, kb_id: str):
|
||||
"""
|
||||
Retrieves files associated with a specific knowledge base by querying the File table directly.
|
||||
"""
|
||||
@@ -153,7 +153,7 @@ class RAGManager:
|
||||
file_obj = None
|
||||
|
||||
try:
|
||||
# 1. 确保知识库存在或创建它
|
||||
|
||||
kb = session.query(KnowledgeBase).filter_by(name=kb_name).first()
|
||||
if not kb:
|
||||
kb = KnowledgeBase(name=kb_name, description=kb_description)
|
||||
@@ -164,7 +164,7 @@ class RAGManager:
|
||||
else:
|
||||
self.logger.info(f"Knowledge Base '{kb_name}' already exists.")
|
||||
|
||||
# 2. 添加文件记录到数据库,并直接关联 kb_id
|
||||
|
||||
file_name = os.path.basename(file_path)
|
||||
existing_file = session.query(File).filter_by(kb_id=kb.id, file_name=file_name).first()
|
||||
if existing_file:
|
||||
@@ -181,15 +181,15 @@ class RAGManager:
|
||||
f"File record '{file_name}' added to database with ID: {file_obj.id}, associated with KB ID: {kb.id}"
|
||||
)
|
||||
|
||||
# 3. 解析文件内容
|
||||
|
||||
text = await self.parser.parse(file_path)
|
||||
if not text:
|
||||
self.logger.warning(f'No text extracted from file {file_path}. Deleting file record ID: {file_obj.id}.')
|
||||
session.delete(file_obj)
|
||||
session.commit() # 提交删除操作
|
||||
session.commit()
|
||||
return
|
||||
|
||||
# 4. 分块并嵌入/存储块
|
||||
|
||||
chunks_texts = await self.chunker.chunk(text)
|
||||
self.logger.info(f"Chunked file '{file_name}' into {len(chunks_texts)} chunks.")
|
||||
await self.embedder.embed_and_store(file_id=file_obj.id, chunks=chunks_texts)
|
||||
@@ -222,7 +222,7 @@ class RAGManager:
|
||||
self.logger.error(f"Error in retrieve_data for query '{query}': {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
async def delete_data_by_file_id(self, file_id: int):
|
||||
async def delete_data_by_file_id(self, file_id: str):
|
||||
"""
|
||||
Deletes all data associated with a specific file ID, including its chunks and vectors,
|
||||
and the file record itself.
|
||||
@@ -257,13 +257,13 @@ class RAGManager:
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
async def delete_kb_by_id(self, kb_id: int):
|
||||
async def delete_kb_by_id(self, kb_id: str):
|
||||
"""
|
||||
Deletes a knowledge base and all associated files, chunks, and vectors.
|
||||
This involves querying for associated files and then deleting them.
|
||||
"""
|
||||
self.logger.info(f'Starting deletion of knowledge base with ID: {kb_id}')
|
||||
session = SessionLocal() # 使用新的会话来获取 KB 和关联文件
|
||||
session = SessionLocal()
|
||||
|
||||
try:
|
||||
kb_to_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
@@ -271,24 +271,24 @@ class RAGManager:
|
||||
self.logger.warning(f'Knowledge Base with ID {kb_id} not found.')
|
||||
return
|
||||
|
||||
# 获取所有关联的文件,通过 File 表的 kb_id 字段查询
|
||||
|
||||
files_to_delete = session.query(File).filter_by(kb_id=kb_id).all()
|
||||
|
||||
# 关闭当前会话,因为 delete_data_by_file_id 会创建自己的会话
|
||||
|
||||
session.close()
|
||||
|
||||
# 遍历删除每个关联文件及其数据
|
||||
|
||||
for file_obj in files_to_delete:
|
||||
try:
|
||||
await self.delete_data_by_file_id(file_obj.id)
|
||||
except Exception as file_del_e:
|
||||
self.logger.error(f'Failed to delete file ID {file_obj.id} during KB deletion: {file_del_e}')
|
||||
# 记录错误但继续,尝试删除其他文件
|
||||
|
||||
|
||||
# 所有文件删除完毕后,重新打开会话来删除 KnowledgeBase 本身
|
||||
|
||||
session = SessionLocal()
|
||||
try:
|
||||
# 重新查询,确保对象是当前会话的一部分
|
||||
|
||||
kb_final_delete = session.query(KnowledgeBase).filter_by(id=kb_id).first()
|
||||
if kb_final_delete:
|
||||
session.delete(kb_final_delete)
|
||||
|
||||
Reference in New Issue
Block a user