mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
perf: ruff check --fix
This commit is contained in:
@@ -1,19 +1,20 @@
|
||||
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
|
||||
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
|
||||
from datetime import datetime
|
||||
import numpy as np # 用于处理从LargeBinary转换回来的embedding
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
|
||||
class KnowledgeBase(Base):
|
||||
__tablename__ = 'kb'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, index=True)
|
||||
description = Column(Text)
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
embedding_model = Column(String, default="") # 默认嵌入模型
|
||||
embedding_model = Column(String, default='') # 默认嵌入模型
|
||||
top_k = Column(Integer, default=5) # 默认返回的top_k数量
|
||||
files = relationship("File", back_populates="knowledge_base")
|
||||
files = relationship('File', back_populates='knowledge_base')
|
||||
|
||||
|
||||
class File(Base):
|
||||
__tablename__ = 'file'
|
||||
@@ -24,8 +25,9 @@ class File(Base):
|
||||
created_at = Column(DateTime, default=datetime.utcnow)
|
||||
file_type = Column(String)
|
||||
status = Column(Integer, default=0) # 0: 未处理, 1: 处理中, 2: 已处理, 3: 错误
|
||||
knowledge_base = relationship("KnowledgeBase", back_populates="files")
|
||||
chunks = relationship("Chunk", back_populates="file")
|
||||
knowledge_base = relationship('KnowledgeBase', back_populates='files')
|
||||
chunks = relationship('Chunk', back_populates='file')
|
||||
|
||||
|
||||
class Chunk(Base):
|
||||
__tablename__ = 'chunks'
|
||||
@@ -33,26 +35,30 @@ class Chunk(Base):
|
||||
file_id = Column(Integer, ForeignKey('file.id'))
|
||||
text = Column(Text)
|
||||
|
||||
file = relationship("File", back_populates="chunks")
|
||||
vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one
|
||||
file = relationship('File', back_populates='chunks')
|
||||
vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one
|
||||
|
||||
|
||||
class Vector(Base):
|
||||
__tablename__ = 'vectors'
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
embedding = Column(LargeBinary) # Store embeddings as binary
|
||||
|
||||
chunk = relationship('Chunk', back_populates='vector')
|
||||
|
||||
chunk = relationship("Chunk", back_populates="vector")
|
||||
|
||||
# 数据库连接
|
||||
DATABASE_URL = "sqlite:///./knowledge_base.db" # 生产环境请更换为 PostgreSQL/MySQL
|
||||
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {})
|
||||
DATABASE_URL = 'sqlite:///./knowledge_base.db' # 生产环境请更换为 PostgreSQL/MySQL
|
||||
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {})
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
# 创建所有表 (可以在应用启动时执行一次)
|
||||
def create_db_and_tables():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
print("Database tables created/checked.")
|
||||
print('Database tables created/checked.')
|
||||
|
||||
|
||||
# 定义嵌入维度(请根据你实际使用的模型调整)
|
||||
EMBEDDING_DIM = 1024
|
||||
EMBEDDING_DIM = 1024
|
||||
|
||||
@@ -1,14 +1,15 @@
|
||||
# services/embedding_models.py
|
||||
|
||||
import os
|
||||
from typing import Dict, Any, List, Type, Optional
|
||||
from typing import Dict, Any, List
|
||||
import logging
|
||||
import aiohttp # Import aiohttp for asynchronous requests
|
||||
import aiohttp # Import aiohttp for asynchronous requests
|
||||
import asyncio
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Base class for all embedding models
|
||||
class BaseEmbeddingModel:
|
||||
def __init__(self, model_name: str):
|
||||
@@ -27,9 +28,10 @@ class BaseEmbeddingModel:
|
||||
def embedding_dimension(self) -> int:
|
||||
"""Returns the embedding dimension of the model."""
|
||||
if self._embedding_dimension is None:
|
||||
raise NotImplementedError("Embedding dimension not set for this model.")
|
||||
raise NotImplementedError('Embedding dimension not set for this model.')
|
||||
return self._embedding_dimension
|
||||
|
||||
|
||||
|
||||
class EmbeddingModelFactory:
|
||||
@staticmethod
|
||||
def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel:
|
||||
@@ -39,26 +41,29 @@ class EmbeddingModelFactory:
|
||||
"""
|
||||
if model_name_key not in EMBEDDING_MODEL_CONFIGS:
|
||||
raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.")
|
||||
|
||||
|
||||
config = EMBEDDING_MODEL_CONFIGS[model_name_key]
|
||||
|
||||
if config['type'] == "third_party_api":
|
||||
|
||||
if config['type'] == 'third_party_api':
|
||||
required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension']
|
||||
if not all(key in config for key in required_keys):
|
||||
raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}")
|
||||
|
||||
raise ValueError(
|
||||
f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}"
|
||||
)
|
||||
|
||||
# Retrieve model_name from config if it differs from model_name_key
|
||||
# Some APIs expect a specific 'model' value in the payload that might be different from the key
|
||||
api_model_name = config.get('model_name', model_name_key)
|
||||
api_model_name = config.get('model_name', model_name_key)
|
||||
|
||||
return ThirdPartyAPIEmbeddingModel(
|
||||
model_name=api_model_name, # Use the model_name from config or the key
|
||||
model_name=api_model_name, # Use the model_name from config or the key
|
||||
api_endpoint=config['api_endpoint'],
|
||||
headers=config['headers'],
|
||||
payload_template=config['payload_template'],
|
||||
embedding_dimension=config['embedding_dimension']
|
||||
embedding_dimension=config['embedding_dimension'],
|
||||
)
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
@@ -68,9 +73,11 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
|
||||
# if not run in a separate thread/process, but this keeps the API consistent.
|
||||
self.model = SentenceTransformer(model_name)
|
||||
self._embedding_dimension = self.model.get_sentence_embedding_dimension()
|
||||
logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}")
|
||||
logger.info(
|
||||
f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
|
||||
logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}')
|
||||
raise
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
@@ -84,14 +91,23 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
|
||||
|
||||
|
||||
class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
|
||||
def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_endpoint: str,
|
||||
headers: Dict[str, str],
|
||||
payload_template: Dict[str, Any],
|
||||
embedding_dimension: int,
|
||||
):
|
||||
super().__init__(model_name)
|
||||
self.api_endpoint = api_endpoint
|
||||
self.headers = headers
|
||||
self.payload_template = payload_template
|
||||
self._embedding_dimension = embedding_dimension
|
||||
self.session = None # aiohttp client session will be initialized on first use or in a context manager
|
||||
logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}")
|
||||
self.session = None # aiohttp client session will be initialized on first use or in a context manager
|
||||
logger.info(
|
||||
f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}"
|
||||
)
|
||||
|
||||
async def _get_session(self):
|
||||
"""Lazily create or return the aiohttp client session."""
|
||||
@@ -104,7 +120,7 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
logger.info(f"Closed aiohttp session for model {self.model_name}")
|
||||
logger.info(f'Closed aiohttp session for model {self.model_name}')
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronously embeds a list of texts using the third-party API."""
|
||||
@@ -118,10 +134,10 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
|
||||
elif 'texts' in payload:
|
||||
payload['texts'] = [text]
|
||||
else:
|
||||
raise ValueError("Payload template does not contain expected text input key.")
|
||||
raise ValueError('Payload template does not contain expected text input key.')
|
||||
|
||||
tasks.append(self._make_api_request(session, payload))
|
||||
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, res in enumerate(results):
|
||||
@@ -131,93 +147,92 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
|
||||
# - Append None or an empty list
|
||||
# - Re-raise the exception to stop processing
|
||||
# - Log and skip, then continue
|
||||
embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure
|
||||
embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure
|
||||
else:
|
||||
embeddings.append(res)
|
||||
|
||||
|
||||
return embeddings
|
||||
|
||||
async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]:
|
||||
"""Helper to make an asynchronous API request and extract embedding."""
|
||||
try:
|
||||
async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response:
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
|
||||
api_response = await response.json()
|
||||
|
||||
|
||||
# Adjust this based on your API's actual response structure
|
||||
if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]:
|
||||
embedding = api_response["data"][0]["embedding"]
|
||||
if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]:
|
||||
embedding = api_response['data'][0]['embedding']
|
||||
if len(embedding) != self.embedding_dimension:
|
||||
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
|
||||
logger.warning(
|
||||
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
|
||||
)
|
||||
return embedding
|
||||
elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]:
|
||||
embedding = api_response["embeddings"][0]
|
||||
elif (
|
||||
'embeddings' in api_response
|
||||
and isinstance(api_response['embeddings'], list)
|
||||
and api_response['embeddings']
|
||||
):
|
||||
embedding = api_response['embeddings'][0]
|
||||
if len(embedding) != self.embedding_dimension:
|
||||
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
|
||||
logger.warning(
|
||||
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
|
||||
)
|
||||
return embedding
|
||||
else:
|
||||
raise ValueError(f"Unexpected API response structure: {api_response}")
|
||||
raise ValueError(f'Unexpected API response structure: {api_response}')
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f"API request failed: {e}") from e
|
||||
raise ConnectionError(f'API request failed: {e}') from e
|
||||
except ValueError as e:
|
||||
raise ValueError(f"Error processing API response: {e}") from e
|
||||
|
||||
raise ValueError(f'Error processing API response: {e}') from e
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronously embeds a single query text."""
|
||||
results = await self.embed_documents([text])
|
||||
if results:
|
||||
return results[0]
|
||||
return [] # Or raise an error if embedding a query must always succeed
|
||||
return [] # Or raise an error if embedding a query must always succeed
|
||||
|
||||
|
||||
# --- Embedding Model Configuration ---
|
||||
EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
|
||||
"MiniLM": { # Example for a local Sentence Transformer model
|
||||
"type": "sentence_transformer",
|
||||
"model_name": "sentence-transformers/all-MiniLM-L6-v2"
|
||||
'MiniLM': { # Example for a local Sentence Transformer model
|
||||
'type': 'sentence_transformer',
|
||||
'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
|
||||
},
|
||||
"bge-m3": { # Example for a third-party API model
|
||||
"type": "third_party_api",
|
||||
"model_name": "bge-m3",
|
||||
"api_endpoint": "https://api.qhaigc.net/v1/embeddings",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.getenv('rag_api_key')}"
|
||||
},
|
||||
"payload_template": {
|
||||
"model": "bge-m3",
|
||||
"input": ""
|
||||
},
|
||||
"embedding_dimension": 1024
|
||||
'bge-m3': { # Example for a third-party API model
|
||||
'type': 'third_party_api',
|
||||
'model_name': 'bge-m3',
|
||||
'api_endpoint': 'https://api.qhaigc.net/v1/embeddings',
|
||||
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'},
|
||||
'payload_template': {'model': 'bge-m3', 'input': ''},
|
||||
'embedding_dimension': 1024,
|
||||
},
|
||||
"OpenAI-Ada-002": {
|
||||
"type": "third_party_api",
|
||||
"model_name": "text-embedding-ada-002",
|
||||
"api_endpoint": "https://api.openai.com/v1/embeddings",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set
|
||||
'OpenAI-Ada-002': {
|
||||
'type': 'third_party_api',
|
||||
'model_name': 'text-embedding-ada-002',
|
||||
'api_endpoint': 'https://api.openai.com/v1/embeddings',
|
||||
'headers': {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set
|
||||
},
|
||||
"payload_template": {
|
||||
"model": "text-embedding-ada-002",
|
||||
"input": "" # Text will be injected here
|
||||
'payload_template': {
|
||||
'model': 'text-embedding-ada-002',
|
||||
'input': '', # Text will be injected here
|
||||
},
|
||||
"embedding_dimension": 1536
|
||||
'embedding_dimension': 1536,
|
||||
},
|
||||
"OpenAI-Embedding-3-Small": {
|
||||
"type": "third_party_api",
|
||||
"model_name": "text-embedding-3-small",
|
||||
"api_endpoint": "https://api.openai.com/v1/embeddings",
|
||||
"headers": {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
|
||||
},
|
||||
"payload_template": {
|
||||
"model": "text-embedding-3-small",
|
||||
"input": "",
|
||||
'OpenAI-Embedding-3-Small': {
|
||||
'type': 'third_party_api',
|
||||
'model_name': 'text-embedding-3-small',
|
||||
'api_endpoint': 'https://api.openai.com/v1/embeddings',
|
||||
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'},
|
||||
'payload_template': {
|
||||
'model': 'text-embedding-3-small',
|
||||
'input': '',
|
||||
# "dimensions": 512 # Optional: uncomment if you want a specific output dimension
|
||||
},
|
||||
"embedding_dimension": 1536 # Default max dimension for text-embedding-3-small
|
||||
'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,22 +1,21 @@
|
||||
|
||||
import PyPDF2
|
||||
from docx import Document
|
||||
import pandas as pd
|
||||
import csv
|
||||
import chardet
|
||||
from typing import Union, List, Callable, Any
|
||||
from typing import Union, Callable, Any
|
||||
import logging
|
||||
import markdown
|
||||
from bs4 import BeautifulSoup
|
||||
import ebooklib
|
||||
from ebooklib import epub
|
||||
import re
|
||||
import asyncio # Import asyncio for async operations
|
||||
import asyncio # Import asyncio for async operations
|
||||
import os
|
||||
|
||||
# Configure logging
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FileParser:
|
||||
"""
|
||||
A robust file parser class to extract text content from various document formats.
|
||||
@@ -24,8 +23,8 @@ class FileParser:
|
||||
All core file reading operations are designed to be run synchronously in a thread pool
|
||||
to avoid blocking the asyncio event loop.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
|
||||
self.logger = logging.getLogger(self.__class__.__name__)
|
||||
|
||||
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
|
||||
@@ -36,14 +35,14 @@ class FileParser:
|
||||
try:
|
||||
return await asyncio.to_thread(sync_func, *args, **kwargs)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}")
|
||||
self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
|
||||
raise
|
||||
|
||||
async def parse(self, file_path: str) -> Union[str, None]:
|
||||
"""
|
||||
Parses the file based on its extension and returns the extracted text content.
|
||||
This is the main asynchronous entry point for parsing.
|
||||
|
||||
|
||||
Args:
|
||||
file_path (str): The path to the file to be parsed.
|
||||
|
||||
@@ -51,21 +50,21 @@ class FileParser:
|
||||
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
|
||||
"""
|
||||
if not file_path or not os.path.exists(file_path):
|
||||
self.logger.error(f"Invalid file path provided: {file_path}")
|
||||
self.logger.error(f'Invalid file path provided: {file_path}')
|
||||
return None
|
||||
|
||||
file_extension = file_path.split('.')[-1].lower()
|
||||
parser_method = getattr(self, f'_parse_{file_extension}', None)
|
||||
|
||||
|
||||
if parser_method is None:
|
||||
self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}")
|
||||
self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}')
|
||||
return None
|
||||
|
||||
|
||||
try:
|
||||
# Pass file_path to the specific parser methods
|
||||
return await parser_method(file_path)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}")
|
||||
self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}')
|
||||
return None
|
||||
|
||||
# --- Helper for reading files with encoding detection ---
|
||||
@@ -74,15 +73,16 @@ class FileParser:
|
||||
Reads a file with automatic encoding detection, ensuring the synchronous
|
||||
file read operation runs in a separate thread.
|
||||
"""
|
||||
|
||||
def _read_sync():
|
||||
with open(file_path, 'rb') as file:
|
||||
raw_data = file.read()
|
||||
detected = chardet.detect(raw_data)
|
||||
encoding = detected['encoding'] or 'utf-8'
|
||||
|
||||
|
||||
if mode == 'r':
|
||||
return raw_data.decode(encoding, errors='ignore')
|
||||
return raw_data # For binary mode
|
||||
return raw_data # For binary mode
|
||||
|
||||
return await self._run_sync(_read_sync)
|
||||
|
||||
@@ -90,12 +90,13 @@ class FileParser:
|
||||
|
||||
async def _parse_txt(self, file_path: str) -> str:
|
||||
"""Parses a TXT file and returns its content."""
|
||||
self.logger.info(f"Parsing TXT file: {file_path}")
|
||||
self.logger.info(f'Parsing TXT file: {file_path}')
|
||||
return await self._read_file_content(file_path, mode='r')
|
||||
|
||||
async def _parse_pdf(self, file_path: str) -> str:
|
||||
"""Parses a PDF file and returns its text content."""
|
||||
self.logger.info(f"Parsing PDF file: {file_path}")
|
||||
self.logger.info(f'Parsing PDF file: {file_path}')
|
||||
|
||||
def _parse_pdf_sync():
|
||||
text_content = []
|
||||
with open(file_path, 'rb') as file:
|
||||
@@ -105,57 +106,69 @@ class FileParser:
|
||||
if text:
|
||||
text_content.append(text)
|
||||
return '\n'.join(text_content)
|
||||
|
||||
return await self._run_sync(_parse_pdf_sync)
|
||||
|
||||
async def _parse_docx(self, file_path: str) -> str:
|
||||
"""Parses a DOCX file and returns its text content."""
|
||||
self.logger.info(f"Parsing DOCX file: {file_path}")
|
||||
self.logger.info(f'Parsing DOCX file: {file_path}')
|
||||
|
||||
def _parse_docx_sync():
|
||||
doc = Document(file_path)
|
||||
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
|
||||
return '\n'.join(text_content)
|
||||
|
||||
return await self._run_sync(_parse_docx_sync)
|
||||
|
||||
|
||||
async def _parse_doc(self, file_path: str) -> str:
|
||||
"""Handles .doc files, explicitly stating lack of direct support."""
|
||||
self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.")
|
||||
raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.")
|
||||
|
||||
self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.')
|
||||
raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')
|
||||
|
||||
async def _parse_xlsx(self, file_path: str) -> str:
|
||||
"""Parses an XLSX file, returning text from all sheets."""
|
||||
self.logger.info(f"Parsing XLSX file: {file_path}")
|
||||
self.logger.info(f'Parsing XLSX file: {file_path}')
|
||||
|
||||
def _parse_xlsx_sync():
|
||||
excel_file = pd.ExcelFile(file_path)
|
||||
all_sheet_content = []
|
||||
for sheet_name in excel_file.sheet_names:
|
||||
df = pd.read_excel(file_path, sheet_name=sheet_name)
|
||||
sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n"
|
||||
sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
|
||||
all_sheet_content.append(sheet_text)
|
||||
return '\n'.join(all_sheet_content)
|
||||
|
||||
return await self._run_sync(_parse_xlsx_sync)
|
||||
|
||||
|
||||
async def _parse_csv(self, file_path: str) -> str:
|
||||
"""Parses a CSV file and returns its content as a string."""
|
||||
self.logger.info(f"Parsing CSV file: {file_path}")
|
||||
self.logger.info(f'Parsing CSV file: {file_path}')
|
||||
|
||||
def _parse_csv_sync():
|
||||
# pd.read_csv can often detect encoding, but explicit detection is safer
|
||||
raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function
|
||||
raw_data = self._read_file_content(
|
||||
file_path, mode='rb'
|
||||
) # Note: this will need to be await outside this sync function
|
||||
_ = raw_data
|
||||
# For simplicity, we'll let pandas handle encoding internally after a raw read.
|
||||
# A more robust solution might pass encoding directly to pd.read_csv after detection.
|
||||
detected = chardet.detect(open(file_path, 'rb').read())
|
||||
encoding = detected['encoding'] or 'utf-8'
|
||||
df = pd.read_csv(file_path, encoding=encoding)
|
||||
return df.to_string(index=False)
|
||||
|
||||
return await self._run_sync(_parse_csv_sync)
|
||||
|
||||
|
||||
async def _parse_markdown(self, file_path: str) -> str:
|
||||
"""Parses a Markdown file, converting it to structured plain text."""
|
||||
self.logger.info(f"Parsing Markdown file: {file_path}")
|
||||
self.logger.info(f'Parsing Markdown file: {file_path}')
|
||||
|
||||
def _parse_markdown_sync():
|
||||
md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function
|
||||
md_content = self._read_file_content(
|
||||
file_path, mode='r'
|
||||
) # This is a synchronous call within a sync function
|
||||
html_content = markdown.markdown(
|
||||
md_content,
|
||||
extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
|
||||
md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
|
||||
)
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
text_parts = []
|
||||
@@ -169,13 +182,13 @@ class FileParser:
|
||||
text_parts.append(text)
|
||||
elif element.name in ['ul', 'ol']:
|
||||
for li in element.find_all('li'):
|
||||
text_parts.append(f"* {li.get_text().strip()}")
|
||||
text_parts.append(f'* {li.get_text().strip()}')
|
||||
elif element.name == 'pre':
|
||||
code_block = element.get_text().strip()
|
||||
if code_block:
|
||||
text_parts.append(f"```\n{code_block}\n```")
|
||||
text_parts.append(f'```\n{code_block}\n```')
|
||||
elif element.name == 'table':
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
if table_str:
|
||||
text_parts.append(table_str)
|
||||
elif element.name:
|
||||
@@ -184,15 +197,17 @@ class FileParser:
|
||||
text_parts.append(text)
|
||||
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
|
||||
return cleaned_text.strip()
|
||||
|
||||
return await self._run_sync(_parse_markdown_sync)
|
||||
|
||||
async def _parse_html(self, file_path: str) -> str:
|
||||
"""Parses an HTML file, extracting structured plain text."""
|
||||
self.logger.info(f"Parsing HTML file: {file_path}")
|
||||
self.logger.info(f'Parsing HTML file: {file_path}')
|
||||
|
||||
def _parse_html_sync():
|
||||
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
|
||||
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
for script_or_style in soup(["script", "style"]):
|
||||
for script_or_style in soup(['script', 'style']):
|
||||
script_or_style.decompose()
|
||||
text_parts = []
|
||||
for element in soup.body.children if soup.body else soup.children:
|
||||
@@ -207,9 +222,9 @@ class FileParser:
|
||||
for li in element.find_all('li'):
|
||||
text = li.get_text().strip()
|
||||
if text:
|
||||
text_parts.append(f"* {text}")
|
||||
text_parts.append(f'* {text}')
|
||||
elif element.name == 'table':
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
|
||||
if table_str:
|
||||
text_parts.append(table_str)
|
||||
elif element.name:
|
||||
@@ -218,39 +233,42 @@ class FileParser:
|
||||
text_parts.append(text)
|
||||
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
|
||||
return cleaned_text.strip()
|
||||
|
||||
return await self._run_sync(_parse_html_sync)
|
||||
|
||||
|
||||
async def _parse_epub(self, file_path: str) -> str:
|
||||
"""Parses an EPUB file, extracting metadata and content."""
|
||||
self.logger.info(f"Parsing EPUB file: {file_path}")
|
||||
self.logger.info(f'Parsing EPUB file: {file_path}')
|
||||
|
||||
def _parse_epub_sync():
|
||||
book = epub.read_epub(file_path)
|
||||
text_content = []
|
||||
title_meta = book.get_metadata('DC', 'title')
|
||||
if title_meta:
|
||||
text_content.append(f"Title: {title_meta[0][0]}")
|
||||
text_content.append(f'Title: {title_meta[0][0]}')
|
||||
creator_meta = book.get_metadata('DC', 'creator')
|
||||
if creator_meta:
|
||||
text_content.append(f"Author: {creator_meta[0][0]}")
|
||||
text_content.append(f'Author: {creator_meta[0][0]}')
|
||||
date_meta = book.get_metadata('DC', 'date')
|
||||
if date_meta:
|
||||
text_content.append(f"Publish Date: {date_meta[0][0]}")
|
||||
text_content.append(f'Publish Date: {date_meta[0][0]}')
|
||||
toc = book.get_toc()
|
||||
if toc:
|
||||
text_content.append("\n--- Table of Contents ---")
|
||||
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
|
||||
text_content.append("--- End of Table of Contents ---\n")
|
||||
text_content.append('\n--- Table of Contents ---')
|
||||
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
|
||||
text_content.append('--- End of Table of Contents ---\n')
|
||||
for item in book.get_items():
|
||||
if item.get_type() == ebooklib.ITEM_DOCUMENT:
|
||||
html_content = item.get_content().decode('utf-8', errors='ignore')
|
||||
soup = BeautifulSoup(html_content, 'html.parser')
|
||||
for junk in soup(["script", "style", "nav", "header", "footer"]):
|
||||
for junk in soup(['script', 'style', 'nav', 'header', 'footer']):
|
||||
junk.decompose()
|
||||
text = soup.get_text(separator='\n', strip=True)
|
||||
text = re.sub(r'\n\s*\n', '\n\n', text)
|
||||
if text:
|
||||
text_content.append(text)
|
||||
return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip()
|
||||
|
||||
return await self._run_sync(_parse_epub_sync)
|
||||
|
||||
def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
|
||||
@@ -259,10 +277,10 @@ class FileParser:
|
||||
for item in toc_list:
|
||||
if isinstance(item, tuple):
|
||||
chapter, subchapters = item
|
||||
text_content.append(f"{indent}- {chapter.title}")
|
||||
text_content.append(f'{indent}- {chapter.title}')
|
||||
self._add_toc_items_sync(subchapters, text_content, level + 1)
|
||||
else:
|
||||
text_content.append(f"{indent}- {item.title}")
|
||||
text_content.append(f'{indent}- {item.title}')
|
||||
|
||||
def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
|
||||
"""Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
|
||||
@@ -272,17 +290,17 @@ class FileParser:
|
||||
cells = [td.get_text().strip() for td in tr.find_all('td')]
|
||||
if cells:
|
||||
rows.append(cells)
|
||||
|
||||
|
||||
if not headers and not rows:
|
||||
return ""
|
||||
return ''
|
||||
|
||||
table_lines = []
|
||||
if headers:
|
||||
table_lines.append(' | '.join(headers))
|
||||
table_lines.append(' | '.join(['---'] * len(headers)))
|
||||
|
||||
|
||||
for row_cells in rows:
|
||||
padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
|
||||
table_lines.append(' | '.join(padded_cells))
|
||||
|
||||
return '\n'.join(table_lines)
|
||||
|
||||
return '\n'.join(table_lines)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# services/retriever.py
|
||||
import asyncio
|
||||
import logging
|
||||
import numpy as np # Make sure numpy is imported
|
||||
import numpy as np # Make sure numpy is imported
|
||||
from typing import List, Dict, Any
|
||||
from sqlalchemy.orm import Session
|
||||
from pkg.rag.knowledge.services.base_service import BaseService
|
||||
@@ -11,6 +10,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Retriever(BaseService):
|
||||
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
|
||||
super().__init__()
|
||||
@@ -22,10 +22,14 @@ class Retriever(BaseService):
|
||||
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
|
||||
|
||||
def _load_embedding_model(self) -> BaseEmbeddingModel:
|
||||
self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...")
|
||||
self.logger.info(
|
||||
f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...'
|
||||
)
|
||||
try:
|
||||
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
|
||||
self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
|
||||
self.logger.info(
|
||||
f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}"
|
||||
)
|
||||
return model
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
|
||||
@@ -33,43 +37,42 @@ class Retriever(BaseService):
|
||||
|
||||
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
|
||||
if not self.embedding_model:
|
||||
raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.")
|
||||
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
|
||||
|
||||
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
|
||||
|
||||
query_embedding: List[float] = await self.embedding_model.embed_query(query)
|
||||
query_embedding_np = np.array([query_embedding], dtype=np.float32)
|
||||
|
||||
chroma_results = await self._run_sync(
|
||||
self.chroma_manager.search_sync,
|
||||
query_embedding_np, k
|
||||
)
|
||||
chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k)
|
||||
|
||||
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
|
||||
matched_chroma_ids = chroma_results.get("ids", [[]])[0]
|
||||
distances = chroma_results.get("distances", [[]])[0]
|
||||
chroma_metadatas = chroma_results.get("metadatas", [[]])[0]
|
||||
chroma_documents = chroma_results.get("documents", [[]])[0]
|
||||
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
|
||||
distances = chroma_results.get('distances', [[]])[0]
|
||||
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
|
||||
chroma_documents = chroma_results.get('documents', [[]])[0]
|
||||
|
||||
if not matched_chroma_ids:
|
||||
self.logger.info("No relevant chunks found in Chroma.")
|
||||
self.logger.info('No relevant chunks found in Chroma.')
|
||||
return []
|
||||
|
||||
db_chunk_ids = []
|
||||
for metadata in chroma_metadatas:
|
||||
if "chunk_id" in metadata:
|
||||
db_chunk_ids.append(metadata["chunk_id"])
|
||||
if 'chunk_id' in metadata:
|
||||
db_chunk_ids.append(metadata['chunk_id'])
|
||||
else:
|
||||
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
|
||||
|
||||
if not db_chunk_ids:
|
||||
self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.")
|
||||
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
|
||||
return []
|
||||
|
||||
self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...")
|
||||
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
|
||||
chunks_from_db = await self._run_sync(
|
||||
lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
|
||||
db_chunk_ids
|
||||
lambda cids: self._db_get_chunks_sync(
|
||||
SessionLocal(), cids
|
||||
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
|
||||
db_chunk_ids,
|
||||
)
|
||||
|
||||
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
|
||||
@@ -80,27 +83,29 @@ class Retriever(BaseService):
|
||||
# Ensure original_chunk_id is int for DB lookup
|
||||
original_chunk_id = int(chroma_id.split('_')[-1])
|
||||
except (ValueError, IndexError):
|
||||
self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.")
|
||||
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
|
||||
continue
|
||||
|
||||
chunk_text_from_chroma = chroma_documents[i]
|
||||
distance = float(distances[i])
|
||||
file_id_from_chroma = chroma_metadatas[i].get("file_id")
|
||||
file_id_from_chroma = chroma_metadatas[i].get('file_id')
|
||||
|
||||
chunk_from_db = chunk_map.get(original_chunk_id)
|
||||
|
||||
results_list.append({
|
||||
"chunk_id": original_chunk_id,
|
||||
"text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
|
||||
"distance": distance,
|
||||
"file_id": file_id_from_chroma
|
||||
})
|
||||
results_list.append(
|
||||
{
|
||||
'chunk_id': original_chunk_id,
|
||||
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
|
||||
'distance': distance,
|
||||
'file_id': file_id_from_chroma,
|
||||
}
|
||||
)
|
||||
|
||||
self.logger.info(f"Retrieved {len(results_list)} chunks.")
|
||||
self.logger.info(f'Retrieved {len(results_list)} chunks.')
|
||||
return results_list
|
||||
|
||||
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
|
||||
self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).")
|
||||
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
|
||||
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
|
||||
session.close()
|
||||
return chunks
|
||||
return chunks
|
||||
|
||||
Reference in New Issue
Block a user