perf: ruff check --fix

Junyan Qin
2025-07-05 21:56:54 +08:00
parent 39c062f73e
commit 8d28ace252
23 changed files with 647 additions and 737 deletions

View File

@@ -1,19 +1,20 @@
from sqlalchemy import create_engine, Column, Integer, String, Text, DateTime, ForeignKey, LargeBinary
from sqlalchemy.orm import declarative_base, sessionmaker, relationship
from datetime import datetime
import numpy as np # used for handling embeddings converted back from LargeBinary
Base = declarative_base()
class KnowledgeBase(Base):
__tablename__ = 'kb'
id = Column(Integer, primary_key=True, index=True)
name = Column(String, index=True)
description = Column(Text)
created_at = Column(DateTime, default=datetime.utcnow)
embedding_model = Column(String, default="") # 默认嵌入模型
embedding_model = Column(String, default='') # 默认嵌入模型
top_k = Column(Integer, default=5) # 默认返回的top_k数量
files = relationship("File", back_populates="knowledge_base")
files = relationship('File', back_populates='knowledge_base')
class File(Base):
__tablename__ = 'file'
@@ -24,8 +25,9 @@ class File(Base):
created_at = Column(DateTime, default=datetime.utcnow)
file_type = Column(String)
status = Column(Integer, default=0) # 0: unprocessed, 1: processing, 2: processed, 3: error
knowledge_base = relationship("KnowledgeBase", back_populates="files")
chunks = relationship("Chunk", back_populates="file")
knowledge_base = relationship('KnowledgeBase', back_populates='files')
chunks = relationship('Chunk', back_populates='file')
class Chunk(Base):
__tablename__ = 'chunks'
@@ -33,26 +35,30 @@ class Chunk(Base):
file_id = Column(Integer, ForeignKey('file.id'))
text = Column(Text)
file = relationship("File", back_populates="chunks")
vector = relationship("Vector", uselist=False, back_populates="chunk") # One-to-one
file = relationship('File', back_populates='chunks')
vector = relationship('Vector', uselist=False, back_populates='chunk') # One-to-one
class Vector(Base):
__tablename__ = 'vectors'
id = Column(Integer, primary_key=True, index=True)
chunk_id = Column(Integer, ForeignKey('chunks.id'), unique=True)
embedding = Column(LargeBinary) # Store embeddings as binary
embedding = Column(LargeBinary) # Store embeddings as binary
chunk = relationship('Chunk', back_populates='vector')
chunk = relationship("Chunk", back_populates="vector")
# database connection
DATABASE_URL = "sqlite:///./knowledge_base.db" # switch to PostgreSQL/MySQL in production
engine = create_engine(DATABASE_URL, connect_args={"check_same_thread": False} if "sqlite" in DATABASE_URL else {})
DATABASE_URL = 'sqlite:///./knowledge_base.db'  # switch to PostgreSQL/MySQL in production
engine = create_engine(DATABASE_URL, connect_args={'check_same_thread': False} if 'sqlite' in DATABASE_URL else {})
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# create all tables (can be run once at application startup)
def create_db_and_tables():
Base.metadata.create_all(bind=engine)
print("Database tables created/checked.")
print('Database tables created/checked.')
# define the embedding dimension (adjust according to the model you actually use)
EMBEDDING_DIM = 1024
EMBEDDING_DIM = 1024
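
Usage note (not part of this commit): a minimal sketch of how the session factory and table-creation helper above might be used. The `models` import path and the example field values are assumptions for illustration only.

```python
# Illustrative sketch only; assumes this module is importable as `models`.
from models import SessionLocal, create_db_and_tables, KnowledgeBase

create_db_and_tables()  # Base.metadata.create_all is idempotent

session = SessionLocal()
try:
    kb = KnowledgeBase(name='demo', description='example knowledge base', top_k=5)
    session.add(kb)
    session.commit()
finally:
    session.close()
```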

View File

@@ -1,14 +1,15 @@
# services/embedding_models.py
import os
from typing import Dict, Any, List, Type, Optional
from typing import Dict, Any, List
import logging
import aiohttp # Import aiohttp for asynchronous requests
import aiohttp # Import aiohttp for asynchronous requests
import asyncio
from sentence_transformers import SentenceTransformer
logger = logging.getLogger(__name__)
# Base class for all embedding models
class BaseEmbeddingModel:
def __init__(self, model_name: str):
@@ -27,9 +28,10 @@ class BaseEmbeddingModel:
def embedding_dimension(self) -> int:
"""Returns the embedding dimension of the model."""
if self._embedding_dimension is None:
raise NotImplementedError("Embedding dimension not set for this model.")
raise NotImplementedError('Embedding dimension not set for this model.')
return self._embedding_dimension
class EmbeddingModelFactory:
@staticmethod
def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel:
@@ -39,26 +41,29 @@ class EmbeddingModelFactory:
"""
if model_name_key not in EMBEDDING_MODEL_CONFIGS:
raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.")
config = EMBEDDING_MODEL_CONFIGS[model_name_key]
if config['type'] == "third_party_api":
if config['type'] == 'third_party_api':
required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension']
if not all(key in config for key in required_keys):
raise ValueError(f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}")
raise ValueError(
f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}"
)
# Retrieve model_name from config if it differs from model_name_key
# Some APIs expect a specific 'model' value in the payload that might be different from the key
api_model_name = config.get('model_name', model_name_key)
api_model_name = config.get('model_name', model_name_key)
return ThirdPartyAPIEmbeddingModel(
model_name=api_model_name, # Use the model_name from config or the key
model_name=api_model_name, # Use the model_name from config or the key
api_endpoint=config['api_endpoint'],
headers=config['headers'],
payload_template=config['payload_template'],
embedding_dimension=config['embedding_dimension']
embedding_dimension=config['embedding_dimension'],
)
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_name: str):
super().__init__(model_name)
@@ -68,9 +73,11 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
# if not run in a separate thread/process, but this keeps the API consistent.
self.model = SentenceTransformer(model_name)
self._embedding_dimension = self.model.get_sentence_embedding_dimension()
logger.info(f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}")
logger.info(
f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}"
)
except Exception as e:
logger.error(f"Failed to load SentenceTransformer model {model_name}: {e}")
logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}')
raise
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
@@ -84,14 +91,23 @@ class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
def __init__(self, model_name: str, api_endpoint: str, headers: Dict[str, str], payload_template: Dict[str, Any], embedding_dimension: int):
def __init__(
self,
model_name: str,
api_endpoint: str,
headers: Dict[str, str],
payload_template: Dict[str, Any],
embedding_dimension: int,
):
super().__init__(model_name)
self.api_endpoint = api_endpoint
self.headers = headers
self.payload_template = payload_template
self._embedding_dimension = embedding_dimension
self.session = None # aiohttp client session will be initialized on first use or in a context manager
logger.info(f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}")
self.session = None # aiohttp client session will be initialized on first use or in a context manager
logger.info(
f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}"
)
async def _get_session(self):
"""Lazily create or return the aiohttp client session."""
@@ -104,7 +120,7 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
if self.session and not self.session.closed:
await self.session.close()
self.session = None
logger.info(f"Closed aiohttp session for model {self.model_name}")
logger.info(f'Closed aiohttp session for model {self.model_name}')
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Asynchronously embeds a list of texts using the third-party API."""
@@ -118,10 +134,10 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
elif 'texts' in payload:
payload['texts'] = [text]
else:
raise ValueError("Payload template does not contain expected text input key.")
raise ValueError('Payload template does not contain expected text input key.')
tasks.append(self._make_api_request(session, payload))
results = await asyncio.gather(*tasks, return_exceptions=True)
for i, res in enumerate(results):
@@ -131,93 +147,92 @@ class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
# - Append None or an empty list
# - Re-raise the exception to stop processing
# - Log and skip, then continue
embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure
embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure
else:
embeddings.append(res)
return embeddings
async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]:
"""Helper to make an asynchronous API request and extract embedding."""
try:
async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response:
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
api_response = await response.json()
# Adjust this based on your API's actual response structure
if "data" in api_response and len(api_response["data"]) > 0 and "embedding" in api_response["data"][0]:
embedding = api_response["data"][0]["embedding"]
if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]:
embedding = api_response['data'][0]['embedding']
if len(embedding) != self.embedding_dimension:
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
logger.warning(
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
)
return embedding
elif "embeddings" in api_response and isinstance(api_response["embeddings"], list) and api_response["embeddings"]:
embedding = api_response["embeddings"][0]
elif (
'embeddings' in api_response
and isinstance(api_response['embeddings'], list)
and api_response['embeddings']
):
embedding = api_response['embeddings'][0]
if len(embedding) != self.embedding_dimension:
logger.warning(f"API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.")
logger.warning(
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
)
return embedding
else:
raise ValueError(f"Unexpected API response structure: {api_response}")
raise ValueError(f'Unexpected API response structure: {api_response}')
except aiohttp.ClientError as e:
raise ConnectionError(f"API request failed: {e}") from e
raise ConnectionError(f'API request failed: {e}') from e
except ValueError as e:
raise ValueError(f"Error processing API response: {e}") from e
raise ValueError(f'Error processing API response: {e}') from e
async def embed_query(self, text: str) -> List[float]:
"""Asynchronously embeds a single query text."""
results = await self.embed_documents([text])
if results:
return results[0]
return [] # Or raise an error if embedding a query must always succeed
return [] # Or raise an error if embedding a query must always succeed
# --- Embedding Model Configuration ---
EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
"MiniLM": { # Example for a local Sentence Transformer model
"type": "sentence_transformer",
"model_name": "sentence-transformers/all-MiniLM-L6-v2"
'MiniLM': { # Example for a local Sentence Transformer model
'type': 'sentence_transformer',
'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
},
"bge-m3": { # Example for a third-party API model
"type": "third_party_api",
"model_name": "bge-m3",
"api_endpoint": "https://api.qhaigc.net/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('rag_api_key')}"
},
"payload_template": {
"model": "bge-m3",
"input": ""
},
"embedding_dimension": 1024
'bge-m3': { # Example for a third-party API model
'type': 'third_party_api',
'model_name': 'bge-m3',
'api_endpoint': 'https://api.qhaigc.net/v1/embeddings',
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'},
'payload_template': {'model': 'bge-m3', 'input': ''},
'embedding_dimension': 1024,
},
"OpenAI-Ada-002": {
"type": "third_party_api",
"model_name": "text-embedding-ada-002",
"api_endpoint": "https://api.openai.com/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}" # Ensure OPENAI_API_KEY is set
'OpenAI-Ada-002': {
'type': 'third_party_api',
'model_name': 'text-embedding-ada-002',
'api_endpoint': 'https://api.openai.com/v1/embeddings',
'headers': {
'Content-Type': 'application/json',
'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set
},
"payload_template": {
"model": "text-embedding-ada-002",
"input": "" # Text will be injected here
'payload_template': {
'model': 'text-embedding-ada-002',
'input': '', # Text will be injected here
},
"embedding_dimension": 1536
'embedding_dimension': 1536,
},
"OpenAI-Embedding-3-Small": {
"type": "third_party_api",
"model_name": "text-embedding-3-small",
"api_endpoint": "https://api.openai.com/v1/embeddings",
"headers": {
"Content-Type": "application/json",
"Authorization": f"Bearer {os.getenv('OPENAI_API_KEY')}"
},
"payload_template": {
"model": "text-embedding-3-small",
"input": "",
'OpenAI-Embedding-3-Small': {
'type': 'third_party_api',
'model_name': 'text-embedding-3-small',
'api_endpoint': 'https://api.openai.com/v1/embeddings',
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'},
'payload_template': {
'model': 'text-embedding-3-small',
'input': '',
# "dimensions": 512 # Optional: uncomment if you want a specific output dimension
},
"embedding_dimension": 1536 # Default max dimension for text-embedding-3-small
'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small
},
}
}
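
Usage note (not part of this commit): a short sketch of how the factory and an API-backed model defined above might be called. The `embedding_models` import path is an assumption; 'bge-m3' is one of the keys in EMBEDDING_MODEL_CONFIGS, the call only succeeds if the corresponding API key environment variable is set, and closing the model's aiohttp session is omitted because the closing method's name is not visible in this hunk.

```python
# Illustrative sketch only; assumes the module is importable as `embedding_models`.
import asyncio
from embedding_models import EmbeddingModelFactory

async def main():
    # 'bge-m3' is configured as a third_party_api model in EMBEDDING_MODEL_CONFIGS
    model = EmbeddingModelFactory.create_model('third_party_api', 'bge-m3')
    vector = await model.embed_query('What is retrieval-augmented generation?')
    print(len(vector))  # expected to equal model.embedding_dimension (1024 here)

asyncio.run(main())
```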

View File

@@ -1,22 +1,21 @@
import PyPDF2
from docx import Document
import pandas as pd
import csv
import chardet
from typing import Union, List, Callable, Any
from typing import Union, Callable, Any
import logging
import markdown
from bs4 import BeautifulSoup
import ebooklib
from ebooklib import epub
import re
import asyncio # Import asyncio for async operations
import asyncio # Import asyncio for async operations
import os
# Configure logging
logger = logging.getLogger(__name__)
class FileParser:
"""
A robust file parser class to extract text content from various document formats.
@@ -24,8 +23,8 @@ class FileParser:
All core file reading operations are designed to be run synchronously in a thread pool
to avoid blocking the asyncio event loop.
"""
def __init__(self):
self.logger = logging.getLogger(self.__class__.__name__)
async def _run_sync(self, sync_func: Callable, *args: Any, **kwargs: Any) -> Any:
@@ -36,14 +35,14 @@ class FileParser:
try:
return await asyncio.to_thread(sync_func, *args, **kwargs)
except Exception as e:
self.logger.error(f"Error running synchronous function {sync_func.__name__}: {e}")
self.logger.error(f'Error running synchronous function {sync_func.__name__}: {e}')
raise
async def parse(self, file_path: str) -> Union[str, None]:
"""
Parses the file based on its extension and returns the extracted text content.
This is the main asynchronous entry point for parsing.
Args:
file_path (str): The path to the file to be parsed.
@@ -51,21 +50,21 @@ class FileParser:
Union[str, None]: The extracted text content as a single string, or None if parsing fails.
"""
if not file_path or not os.path.exists(file_path):
self.logger.error(f"Invalid file path provided: {file_path}")
self.logger.error(f'Invalid file path provided: {file_path}')
return None
file_extension = file_path.split('.')[-1].lower()
parser_method = getattr(self, f'_parse_{file_extension}', None)
if parser_method is None:
self.logger.error(f"Unsupported file format: {file_extension} for file {file_path}")
self.logger.error(f'Unsupported file format: {file_extension} for file {file_path}')
return None
try:
# Pass file_path to the specific parser methods
return await parser_method(file_path)
except Exception as e:
self.logger.error(f"Failed to parse {file_extension} file {file_path}: {e}")
self.logger.error(f'Failed to parse {file_extension} file {file_path}: {e}')
return None
# --- Helper for reading files with encoding detection ---
@@ -74,15 +73,16 @@ class FileParser:
Reads a file with automatic encoding detection, ensuring the synchronous
file read operation runs in a separate thread.
"""
def _read_sync():
with open(file_path, 'rb') as file:
raw_data = file.read()
detected = chardet.detect(raw_data)
encoding = detected['encoding'] or 'utf-8'
if mode == 'r':
return raw_data.decode(encoding, errors='ignore')
return raw_data # For binary mode
return raw_data # For binary mode
return await self._run_sync(_read_sync)
@@ -90,12 +90,13 @@ class FileParser:
async def _parse_txt(self, file_path: str) -> str:
"""Parses a TXT file and returns its content."""
self.logger.info(f"Parsing TXT file: {file_path}")
self.logger.info(f'Parsing TXT file: {file_path}')
return await self._read_file_content(file_path, mode='r')
async def _parse_pdf(self, file_path: str) -> str:
"""Parses a PDF file and returns its text content."""
self.logger.info(f"Parsing PDF file: {file_path}")
self.logger.info(f'Parsing PDF file: {file_path}')
def _parse_pdf_sync():
text_content = []
with open(file_path, 'rb') as file:
@@ -105,57 +106,69 @@ class FileParser:
if text:
text_content.append(text)
return '\n'.join(text_content)
return await self._run_sync(_parse_pdf_sync)
async def _parse_docx(self, file_path: str) -> str:
"""Parses a DOCX file and returns its text content."""
self.logger.info(f"Parsing DOCX file: {file_path}")
self.logger.info(f'Parsing DOCX file: {file_path}')
def _parse_docx_sync():
doc = Document(file_path)
text_content = [paragraph.text for paragraph in doc.paragraphs if paragraph.text.strip()]
return '\n'.join(text_content)
return await self._run_sync(_parse_docx_sync)
async def _parse_doc(self, file_path: str) -> str:
"""Handles .doc files, explicitly stating lack of direct support."""
self.logger.warning(f"Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.")
raise NotImplementedError("Direct .doc parsing not supported. Please convert to .docx first.")
self.logger.warning(f'Direct .doc parsing is not supported for {file_path}. Please convert to .docx first.')
raise NotImplementedError('Direct .doc parsing not supported. Please convert to .docx first.')
async def _parse_xlsx(self, file_path: str) -> str:
"""Parses an XLSX file, returning text from all sheets."""
self.logger.info(f"Parsing XLSX file: {file_path}")
self.logger.info(f'Parsing XLSX file: {file_path}')
def _parse_xlsx_sync():
excel_file = pd.ExcelFile(file_path)
all_sheet_content = []
for sheet_name in excel_file.sheet_names:
df = pd.read_excel(file_path, sheet_name=sheet_name)
sheet_text = f"--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n"
sheet_text = f'--- Sheet: {sheet_name} ---\n{df.to_string(index=False)}\n'
all_sheet_content.append(sheet_text)
return '\n'.join(all_sheet_content)
return await self._run_sync(_parse_xlsx_sync)
async def _parse_csv(self, file_path: str) -> str:
"""Parses a CSV file and returns its content as a string."""
self.logger.info(f"Parsing CSV file: {file_path}")
self.logger.info(f'Parsing CSV file: {file_path}')
def _parse_csv_sync():
# pd.read_csv can often detect encoding, but explicit detection is safer
raw_data = self._read_file_content(file_path, mode='rb') # Note: this will need to be await outside this sync function
raw_data = self._read_file_content(
file_path, mode='rb'
) # Note: this will need to be await outside this sync function
_ = raw_data
# For simplicity, we'll let pandas handle encoding internally after a raw read.
# A more robust solution might pass encoding directly to pd.read_csv after detection.
detected = chardet.detect(open(file_path, 'rb').read())
encoding = detected['encoding'] or 'utf-8'
df = pd.read_csv(file_path, encoding=encoding)
return df.to_string(index=False)
return await self._run_sync(_parse_csv_sync)
async def _parse_markdown(self, file_path: str) -> str:
"""Parses a Markdown file, converting it to structured plain text."""
self.logger.info(f"Parsing Markdown file: {file_path}")
self.logger.info(f'Parsing Markdown file: {file_path}')
def _parse_markdown_sync():
md_content = self._read_file_content(file_path, mode='r') # This is a synchronous call within a sync function
md_content = self._read_file_content(
file_path, mode='r'
) # This is a synchronous call within a sync function
html_content = markdown.markdown(
md_content,
extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
md_content, extensions=['extra', 'codehilite', 'tables', 'toc', 'fenced_code']
)
soup = BeautifulSoup(html_content, 'html.parser')
text_parts = []
@@ -169,13 +182,13 @@ class FileParser:
text_parts.append(text)
elif element.name in ['ul', 'ol']:
for li in element.find_all('li'):
text_parts.append(f"* {li.get_text().strip()}")
text_parts.append(f'* {li.get_text().strip()}')
elif element.name == 'pre':
code_block = element.get_text().strip()
if code_block:
text_parts.append(f"```\n{code_block}\n```")
text_parts.append(f'```\n{code_block}\n```')
elif element.name == 'table':
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
if table_str:
text_parts.append(table_str)
elif element.name:
@@ -184,15 +197,17 @@ class FileParser:
text_parts.append(text)
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
return cleaned_text.strip()
return await self._run_sync(_parse_markdown_sync)
async def _parse_html(self, file_path: str) -> str:
"""Parses an HTML file, extracting structured plain text."""
self.logger.info(f"Parsing HTML file: {file_path}")
self.logger.info(f'Parsing HTML file: {file_path}')
def _parse_html_sync():
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
html_content = self._read_file_content(file_path, mode='r') # Sync call within sync function
soup = BeautifulSoup(html_content, 'html.parser')
for script_or_style in soup(["script", "style"]):
for script_or_style in soup(['script', 'style']):
script_or_style.decompose()
text_parts = []
for element in soup.body.children if soup.body else soup.children:
@@ -207,9 +222,9 @@ class FileParser:
for li in element.find_all('li'):
text = li.get_text().strip()
if text:
text_parts.append(f"* {text}")
text_parts.append(f'* {text}')
elif element.name == 'table':
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
table_str = self._extract_table_to_markdown_sync(element) # Call sync helper
if table_str:
text_parts.append(table_str)
elif element.name:
@@ -218,39 +233,42 @@ class FileParser:
text_parts.append(text)
cleaned_text = re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_parts))
return cleaned_text.strip()
return await self._run_sync(_parse_html_sync)
async def _parse_epub(self, file_path: str) -> str:
"""Parses an EPUB file, extracting metadata and content."""
self.logger.info(f"Parsing EPUB file: {file_path}")
self.logger.info(f'Parsing EPUB file: {file_path}')
def _parse_epub_sync():
book = epub.read_epub(file_path)
text_content = []
title_meta = book.get_metadata('DC', 'title')
if title_meta:
text_content.append(f"Title: {title_meta[0][0]}")
text_content.append(f'Title: {title_meta[0][0]}')
creator_meta = book.get_metadata('DC', 'creator')
if creator_meta:
text_content.append(f"Author: {creator_meta[0][0]}")
text_content.append(f'Author: {creator_meta[0][0]}')
date_meta = book.get_metadata('DC', 'date')
if date_meta:
text_content.append(f"Publish Date: {date_meta[0][0]}")
text_content.append(f'Publish Date: {date_meta[0][0]}')
toc = book.get_toc()
if toc:
text_content.append("\n--- Table of Contents ---")
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
text_content.append("--- End of Table of Contents ---\n")
text_content.append('\n--- Table of Contents ---')
self._add_toc_items_sync(toc, text_content, level=0) # Call sync helper
text_content.append('--- End of Table of Contents ---\n')
for item in book.get_items():
if item.get_type() == ebooklib.ITEM_DOCUMENT:
html_content = item.get_content().decode('utf-8', errors='ignore')
soup = BeautifulSoup(html_content, 'html.parser')
for junk in soup(["script", "style", "nav", "header", "footer"]):
for junk in soup(['script', 'style', 'nav', 'header', 'footer']):
junk.decompose()
text = soup.get_text(separator='\n', strip=True)
text = re.sub(r'\n\s*\n', '\n\n', text)
if text:
text_content.append(text)
return re.sub(r'\n\s*\n', '\n\n', '\n'.join(text_content)).strip()
return await self._run_sync(_parse_epub_sync)
def _add_toc_items_sync(self, toc_list: list, text_content: list, level: int):
@@ -259,10 +277,10 @@ class FileParser:
for item in toc_list:
if isinstance(item, tuple):
chapter, subchapters = item
text_content.append(f"{indent}- {chapter.title}")
text_content.append(f'{indent}- {chapter.title}')
self._add_toc_items_sync(subchapters, text_content, level + 1)
else:
text_content.append(f"{indent}- {item.title}")
text_content.append(f'{indent}- {item.title}')
def _extract_table_to_markdown_sync(self, table_element: BeautifulSoup) -> str:
"""Helper to convert a BeautifulSoup table element into a Markdown table string (synchronous)."""
@@ -272,17 +290,17 @@ class FileParser:
cells = [td.get_text().strip() for td in tr.find_all('td')]
if cells:
rows.append(cells)
if not headers and not rows:
return ""
return ''
table_lines = []
if headers:
table_lines.append(' | '.join(headers))
table_lines.append(' | '.join(['---'] * len(headers)))
for row_cells in rows:
padded_cells = row_cells + [''] * (len(headers) - len(row_cells)) if headers else row_cells
table_lines.append(' | '.join(padded_cells))
return '\n'.join(table_lines)
return '\n'.join(table_lines)
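
Usage note (not part of this commit): a minimal sketch of calling the parser above. The `file_parser` import path and the file name are assumptions.

```python
# Illustrative sketch only; assumes the module is importable as `file_parser`.
import asyncio
from file_parser import FileParser

async def main():
    parser = FileParser()
    text = await parser.parse('docs/example.pdf')  # hypothetical path
    if text is None:
        print('Parsing failed or the format is unsupported.')
    else:
        print(text[:200])

asyncio.run(main())
```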

View File

@@ -1,7 +1,6 @@
# services/retriever.py
import asyncio
import logging
import numpy as np # Make sure numpy is imported
import numpy as np # Make sure numpy is imported
from typing import List, Dict, Any
from sqlalchemy.orm import Session
from pkg.rag.knowledge.services.base_service import BaseService
@@ -11,6 +10,7 @@ from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
logger = logging.getLogger(__name__)
class Retriever(BaseService):
def __init__(self, model_type: str, model_name_key: str, chroma_manager: ChromaIndexManager):
super().__init__()
@@ -22,10 +22,14 @@ class Retriever(BaseService):
self.embedding_model: BaseEmbeddingModel = self._load_embedding_model()
def _load_embedding_model(self) -> BaseEmbeddingModel:
self.logger.info(f"Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...")
self.logger.info(
f'Loading retriever embedding model: type={self.model_type}, name_key={self.model_name_key}...'
)
try:
model = EmbeddingModelFactory.create_model(self.model_type, self.model_name_key)
self.logger.info(f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}")
self.logger.info(
f"Retriever embedding model '{self.model_name_key}' loaded. Output dimension: {model.embedding_dimension}"
)
return model
except Exception as e:
self.logger.error(f"Failed to load retriever embedding model '{self.model_name_key}': {e}")
@@ -33,43 +37,42 @@ class Retriever(BaseService):
async def retrieve(self, query: str, k: int = 5) -> List[Dict[str, Any]]:
if not self.embedding_model:
raise RuntimeError("Retriever embedding model not loaded. Please check Retriever initialization.")
raise RuntimeError('Retriever embedding model not loaded. Please check Retriever initialization.')
self.logger.info(f"Retrieving for query: '{query}' with k={k} using {self.model_name_key}")
query_embedding: List[float] = await self.embedding_model.embed_query(query)
query_embedding_np = np.array([query_embedding], dtype=np.float32)
chroma_results = await self._run_sync(
self.chroma_manager.search_sync,
query_embedding_np, k
)
chroma_results = await self._run_sync(self.chroma_manager.search_sync, query_embedding_np, k)
# 'ids' is always returned by ChromaDB, even if not explicitly in 'include'
matched_chroma_ids = chroma_results.get("ids", [[]])[0]
distances = chroma_results.get("distances", [[]])[0]
chroma_metadatas = chroma_results.get("metadatas", [[]])[0]
chroma_documents = chroma_results.get("documents", [[]])[0]
matched_chroma_ids = chroma_results.get('ids', [[]])[0]
distances = chroma_results.get('distances', [[]])[0]
chroma_metadatas = chroma_results.get('metadatas', [[]])[0]
chroma_documents = chroma_results.get('documents', [[]])[0]
if not matched_chroma_ids:
self.logger.info("No relevant chunks found in Chroma.")
self.logger.info('No relevant chunks found in Chroma.')
return []
db_chunk_ids = []
for metadata in chroma_metadatas:
if "chunk_id" in metadata:
db_chunk_ids.append(metadata["chunk_id"])
if 'chunk_id' in metadata:
db_chunk_ids.append(metadata['chunk_id'])
else:
self.logger.warning(f"Metadata missing 'chunk_id': {metadata}. Skipping this entry.")
if not db_chunk_ids:
self.logger.warning("No valid chunk_ids extracted from Chroma results metadata.")
self.logger.warning('No valid chunk_ids extracted from Chroma results metadata.')
return []
self.logger.info(f"Fetching {len(db_chunk_ids)} chunk details from relational database...")
self.logger.info(f'Fetching {len(db_chunk_ids)} chunk details from relational database...')
chunks_from_db = await self._run_sync(
lambda cids: self._db_get_chunks_sync(SessionLocal(), cids), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids
lambda cids: self._db_get_chunks_sync(
SessionLocal(), cids
), # Ensure SessionLocal is passed correctly for _db_get_chunks_sync
db_chunk_ids,
)
chunk_map = {chunk.id: chunk for chunk in chunks_from_db}
@@ -80,27 +83,29 @@ class Retriever(BaseService):
# Ensure original_chunk_id is int for DB lookup
original_chunk_id = int(chroma_id.split('_')[-1])
except (ValueError, IndexError):
self.logger.warning(f"Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.")
self.logger.warning(f'Could not parse chunk_id from Chroma ID: {chroma_id}. Skipping.')
continue
chunk_text_from_chroma = chroma_documents[i]
distance = float(distances[i])
file_id_from_chroma = chroma_metadatas[i].get("file_id")
file_id_from_chroma = chroma_metadatas[i].get('file_id')
chunk_from_db = chunk_map.get(original_chunk_id)
results_list.append({
"chunk_id": original_chunk_id,
"text": chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
"distance": distance,
"file_id": file_id_from_chroma
})
results_list.append(
{
'chunk_id': original_chunk_id,
'text': chunk_from_db.text if chunk_from_db else chunk_text_from_chroma,
'distance': distance,
'file_id': file_id_from_chroma,
}
)
self.logger.info(f"Retrieved {len(results_list)} chunks.")
self.logger.info(f'Retrieved {len(results_list)} chunks.')
return results_list
def _db_get_chunks_sync(self, session: Session, chunk_ids: List[int]) -> List[Chunk]:
self.logger.debug(f"Fetching {len(chunk_ids)} chunk details from database (sync).")
self.logger.debug(f'Fetching {len(chunk_ids)} chunk details from database (sync).')
chunks = session.query(Chunk).filter(Chunk.id.in_(chunk_ids)).all()
session.close()
return chunks
return chunks
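
Usage note (not part of this commit): a sketch of how the retriever above might be driven end to end. The retriever module path is inferred from the pkg.rag.knowledge.services imports shown in the diff, and the ChromaIndexManager constructor arguments are an assumption.

```python
# Illustrative sketch only; module paths inferred from the imports above.
import asyncio
from pkg.rag.knowledge.services.chroma_manager import ChromaIndexManager
from pkg.rag.knowledge.services.retriever import Retriever

async def main():
    chroma = ChromaIndexManager()  # constructor arguments are an assumption
    retriever = Retriever(model_type='third_party_api', model_name_key='bge-m3', chroma_manager=chroma)
    hits = await retriever.retrieve('how are embeddings stored?', k=3)
    for hit in hits:
        # each result dict carries chunk_id, text, distance, and file_id
        print(hit['chunk_id'], round(hit['distance'], 4), hit['text'][:80])

asyncio.run(main())
```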