mirror of
https://github.com/langbot-app/LangBot.git
synced 2025-11-25 11:29:39 +08:00
fix: delete embedding models file
This commit is contained in:
@@ -1,238 +0,0 @@
|
||||
# services/embedding_models.py
|
||||
|
||||
import os
|
||||
from typing import Dict, Any, List
|
||||
import logging
|
||||
import aiohttp # Import aiohttp for asynchronous requests
|
||||
import asyncio
|
||||
from sentence_transformers import SentenceTransformer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# Base class for all embedding models
|
||||
class BaseEmbeddingModel:
|
||||
def __init__(self, model_name: str):
|
||||
self.model_name = model_name
|
||||
self._embedding_dimension = None
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronously embeds a list of texts."""
|
||||
raise NotImplementedError
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronously embeds a single query text."""
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def embedding_dimension(self) -> int:
|
||||
"""Returns the embedding dimension of the model."""
|
||||
if self._embedding_dimension is None:
|
||||
raise NotImplementedError('Embedding dimension not set for this model.')
|
||||
return self._embedding_dimension
|
||||
|
||||
|
||||
class EmbeddingModelFactory:
|
||||
@staticmethod
|
||||
def create_model(model_type: str, model_name_key: str) -> BaseEmbeddingModel:
|
||||
"""
|
||||
Factory method to create an embedding model instance.
|
||||
Currently only supports 'third_party_api' types.
|
||||
"""
|
||||
if model_name_key not in EMBEDDING_MODEL_CONFIGS:
|
||||
raise ValueError(f"Embedding model configuration '{model_name_key}' not found in EMBEDDING_MODEL_CONFIGS.")
|
||||
|
||||
config = EMBEDDING_MODEL_CONFIGS[model_name_key]
|
||||
|
||||
if config['type'] == 'third_party_api':
|
||||
required_keys = ['api_endpoint', 'headers', 'payload_template', 'embedding_dimension']
|
||||
if not all(key in config for key in required_keys):
|
||||
raise ValueError(
|
||||
f"Missing configuration keys for third_party_api model '{model_name_key}'. Required: {required_keys}"
|
||||
)
|
||||
|
||||
# Retrieve model_name from config if it differs from model_name_key
|
||||
# Some APIs expect a specific 'model' value in the payload that might be different from the key
|
||||
api_model_name = config.get('model_name', model_name_key)
|
||||
|
||||
return ThirdPartyAPIEmbeddingModel(
|
||||
model_name=api_model_name, # Use the model_name from config or the key
|
||||
api_endpoint=config['api_endpoint'],
|
||||
headers=config['headers'],
|
||||
payload_template=config['payload_template'],
|
||||
embedding_dimension=config['embedding_dimension'],
|
||||
)
|
||||
|
||||
|
||||
class SentenceTransformerEmbeddingModel(BaseEmbeddingModel):
|
||||
def __init__(self, model_name: str):
|
||||
super().__init__(model_name)
|
||||
try:
|
||||
# SentenceTransformer is inherently synchronous, but we'll wrap its calls
|
||||
# in async methods. The actual computation will still block the event loop
|
||||
# if not run in a separate thread/process, but this keeps the API consistent.
|
||||
self.model = SentenceTransformer(model_name)
|
||||
self._embedding_dimension = self.model.get_sentence_embedding_dimension()
|
||||
logger.info(
|
||||
f"Initialized SentenceTransformer model '{model_name}' with dimension {self._embedding_dimension}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f'Failed to load SentenceTransformer model {model_name}: {e}')
|
||||
raise
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
# For CPU-bound tasks like local model inference, consider running in a thread pool
|
||||
# to prevent blocking the event loop for long operations.
|
||||
# For simplicity here, we'll call it directly.
|
||||
return self.model.encode(texts).tolist()
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
return self.model.encode(text).tolist()
|
||||
|
||||
|
||||
class ThirdPartyAPIEmbeddingModel(BaseEmbeddingModel):
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_endpoint: str,
|
||||
headers: Dict[str, str],
|
||||
payload_template: Dict[str, Any],
|
||||
embedding_dimension: int,
|
||||
):
|
||||
super().__init__(model_name)
|
||||
self.api_endpoint = api_endpoint
|
||||
self.headers = headers
|
||||
self.payload_template = payload_template
|
||||
self._embedding_dimension = embedding_dimension
|
||||
self.session = None # aiohttp client session will be initialized on first use or in a context manager
|
||||
logger.info(
|
||||
f"Initialized ThirdPartyAPIEmbeddingModel '{model_name}' for async calls to {api_endpoint} with dimension {embedding_dimension}"
|
||||
)
|
||||
|
||||
async def _get_session(self):
|
||||
"""Lazily create or return the aiohttp client session."""
|
||||
if self.session is None or self.session.closed:
|
||||
self.session = aiohttp.ClientSession()
|
||||
return self.session
|
||||
|
||||
async def close_session(self):
|
||||
"""Explicitly close the aiohttp client session."""
|
||||
if self.session and not self.session.closed:
|
||||
await self.session.close()
|
||||
self.session = None
|
||||
logger.info(f'Closed aiohttp session for model {self.model_name}')
|
||||
|
||||
async def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Asynchronously embeds a list of texts using the third-party API."""
|
||||
session = await self._get_session()
|
||||
embeddings = []
|
||||
tasks = []
|
||||
for text in texts:
|
||||
payload = self.payload_template.copy()
|
||||
if 'input' in payload:
|
||||
payload['input'] = text
|
||||
elif 'texts' in payload:
|
||||
payload['texts'] = [text]
|
||||
else:
|
||||
raise ValueError('Payload template does not contain expected text input key.')
|
||||
|
||||
tasks.append(self._make_api_request(session, payload))
|
||||
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
for i, res in enumerate(results):
|
||||
if isinstance(res, Exception):
|
||||
logger.error(f"Error embedding text '{texts[i][:50]}...': {res}")
|
||||
# Depending on your error handling strategy, you might:
|
||||
# - Append None or an empty list
|
||||
# - Re-raise the exception to stop processing
|
||||
# - Log and skip, then continue
|
||||
embeddings.append([0.0] * self.embedding_dimension) # Append dummy embedding or handle failure
|
||||
else:
|
||||
embeddings.append(res)
|
||||
|
||||
return embeddings
|
||||
|
||||
async def _make_api_request(self, session: aiohttp.ClientSession, payload: Dict[str, Any]) -> List[float]:
|
||||
"""Helper to make an asynchronous API request and extract embedding."""
|
||||
try:
|
||||
async with session.post(self.api_endpoint, headers=self.headers, json=payload) as response:
|
||||
response.raise_for_status() # Raise an exception for HTTP errors (4xx, 5xx)
|
||||
api_response = await response.json()
|
||||
|
||||
# Adjust this based on your API's actual response structure
|
||||
if 'data' in api_response and len(api_response['data']) > 0 and 'embedding' in api_response['data'][0]:
|
||||
embedding = api_response['data'][0]['embedding']
|
||||
if len(embedding) != self.embedding_dimension:
|
||||
logger.warning(
|
||||
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
|
||||
)
|
||||
return embedding
|
||||
elif (
|
||||
'embeddings' in api_response
|
||||
and isinstance(api_response['embeddings'], list)
|
||||
and api_response['embeddings']
|
||||
):
|
||||
embedding = api_response['embeddings'][0]
|
||||
if len(embedding) != self.embedding_dimension:
|
||||
logger.warning(
|
||||
f'API returned embedding of dimension {len(embedding)}, but expected {self.embedding_dimension} for model {self.model_name}. Adjusting config might be needed.'
|
||||
)
|
||||
return embedding
|
||||
else:
|
||||
raise ValueError(f'Unexpected API response structure: {api_response}')
|
||||
|
||||
except aiohttp.ClientError as e:
|
||||
raise ConnectionError(f'API request failed: {e}') from e
|
||||
except ValueError as e:
|
||||
raise ValueError(f'Error processing API response: {e}') from e
|
||||
|
||||
async def embed_query(self, text: str) -> List[float]:
|
||||
"""Asynchronously embeds a single query text."""
|
||||
results = await self.embed_documents([text])
|
||||
if results:
|
||||
return results[0]
|
||||
return [] # Or raise an error if embedding a query must always succeed
|
||||
|
||||
|
||||
# --- Embedding Model Configuration ---
|
||||
EMBEDDING_MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
|
||||
'MiniLM': { # Example for a local Sentence Transformer model
|
||||
'type': 'sentence_transformer',
|
||||
'model_name': 'sentence-transformers/all-MiniLM-L6-v2',
|
||||
},
|
||||
'bge-m3': { # Example for a third-party API model
|
||||
'type': 'third_party_api',
|
||||
'model_name': 'bge-m3',
|
||||
'api_endpoint': 'https://api.qhaigc.net/v1/embeddings',
|
||||
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("rag_api_key")}'},
|
||||
'payload_template': {'model': 'bge-m3', 'input': ''},
|
||||
'embedding_dimension': 1024,
|
||||
},
|
||||
'OpenAI-Ada-002': {
|
||||
'type': 'third_party_api',
|
||||
'model_name': 'text-embedding-ada-002',
|
||||
'api_endpoint': 'https://api.openai.com/v1/embeddings',
|
||||
'headers': {
|
||||
'Content-Type': 'application/json',
|
||||
'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}', # Ensure OPENAI_API_KEY is set
|
||||
},
|
||||
'payload_template': {
|
||||
'model': 'text-embedding-ada-002',
|
||||
'input': '', # Text will be injected here
|
||||
},
|
||||
'embedding_dimension': 1536,
|
||||
},
|
||||
'OpenAI-Embedding-3-Small': {
|
||||
'type': 'third_party_api',
|
||||
'model_name': 'text-embedding-3-small',
|
||||
'api_endpoint': 'https://api.openai.com/v1/embeddings',
|
||||
'headers': {'Content-Type': 'application/json', 'Authorization': f'Bearer {os.getenv("OPENAI_API_KEY")}'},
|
||||
'payload_template': {
|
||||
'model': 'text-embedding-3-small',
|
||||
'input': '',
|
||||
# "dimensions": 512 # Optional: uncomment if you want a specific output dimension
|
||||
},
|
||||
'embedding_dimension': 1536, # Default max dimension for text-embedding-3-small
|
||||
},
|
||||
}
|
||||
Reference in New Issue
Block a user